
Commit dc8578a

Merge pull request scikit-learn#4828 from untom/maxabs_scaler

[MRG] add MaxAbsScaler

2 parents 87aaabc + ab734b7

File tree

5 files changed: +265 −52 lines changed

doc/modules/preprocessing.rst

Lines changed: 60 additions & 19 deletions
@@ -102,8 +102,10 @@ Scaling features to a range
 ---------------------------
 
 An alternative standardization is scaling features to
-lie between a given minimum and maximum value, often between zero and one.
-This can be achieved using :class:`MinMaxScaler`.
+lie between a given minimum and maximum value, often between zero and one,
+or so that the maximum absolute value of each feature is scaled to unit size.
+This can be achieved using :class:`MinMaxScaler` or :class:`MaxAbsScaler`,
+respectively.
 
 The motivation to use this scaling includes robustness to very small
 standard deviations of features and preserving zero entries in sparse data.
@@ -146,6 +148,62 @@ full formula is::
 
     X_scaled = X_std / (max - min) + min
 
+:class:`MaxAbsScaler` works in a very similar fashion, but scales the data so
+that the training set lies within the range ``[-1, 1]``, by dividing each
+feature by its maximum absolute value. It is meant for data that is already
+centered at zero or for sparse data.
+
+Here is how to use the toy data from the previous example with this scaler::
+
+  >>> X_train = np.array([[ 1., -1.,  2.],
+  ...                     [ 2.,  0.,  0.],
+  ...                     [ 0.,  1., -1.]])
+  ...
+  >>> max_abs_scaler = preprocessing.MaxAbsScaler()
+  >>> X_train_maxabs = max_abs_scaler.fit_transform(X_train)
+  >>> X_train_maxabs                # doctest: +NORMALIZE_WHITESPACE
+  array([[ 0.5, -1. ,  1. ],
+         [ 1. ,  0. ,  0. ],
+         [ 0. ,  1. , -0.5]])
+  >>> X_test = np.array([[ -3., -1.,  4.]])
+  >>> X_test_maxabs = max_abs_scaler.transform(X_test)
+  >>> X_test_maxabs                 # doctest: +NORMALIZE_WHITESPACE
+  array([[-1.5, -1. ,  2. ]])
+  >>> max_abs_scaler.scale_         # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+  array([ 2.,  1.,  2.])
+
+
+As with :func:`scale`, the module further provides a
+convenience function :func:`maxabs_scale` if you don't want to
+create an object.
+
+
+Scaling sparse data
+-------------------
+Centering sparse data would destroy the sparseness structure in the data, and
+thus rarely is a sensible thing to do. However, it can make sense to scale
+sparse inputs, especially if features are on different scales.
+
+:class:`MaxAbsScaler` and :func:`maxabs_scale` were specifically designed
+for scaling sparse data, and are the recommended way to go about this.
+However, :func:`scale` and :class:`StandardScaler` can accept ``scipy.sparse``
+matrices as input, as long as ``with_mean=False`` is explicitly passed
+to the constructor. Otherwise a ``ValueError`` will be raised, as
+silently centering would break the sparsity and would often crash the
+execution by allocating excessive amounts of memory unintentionally.
+:class:`RobustScaler` cannot be fitted to sparse inputs, but you can use
+its ``transform`` method on sparse inputs.
+
+Note that the scalers accept both Compressed Sparse Rows and Compressed
+Sparse Columns formats (see ``scipy.sparse.csr_matrix`` and
+``scipy.sparse.csc_matrix``). Any other sparse input will be **converted to
+the Compressed Sparse Rows representation**. To avoid unnecessary memory
+copies, it is recommended to choose the CSR or CSC representation upstream.
+
+Finally, if the centered data is expected to be small enough, explicitly
+converting the input to an array using the ``toarray`` method of sparse
+matrices is another option.
 
 Scaling data with outliers
 --------------------------
@@ -173,23 +231,6 @@ data.
   or :class:`sklearn.decomposition.RandomizedPCA` with ``whiten=True``
   to further remove the linear correlation across features.
 
-.. topic:: Sparse input
-
-  :func:`scale` and :class:`StandardScaler` accept ``scipy.sparse`` matrices
-  as input **only when with_mean=False is explicitly passed to the
-  constructor**. Otherwise a ``ValueError`` will be raised as
-  silently centering would break the sparsity and would often crash the
-  execution by allocating excessive amounts of memory unintentionally.
-
-  If the centered data is expected to be small enough, explicitly convert
-  the input to an array using the ``toarray`` method of sparse matrices
-  instead.
-
-  For sparse input the data is **converted to the Compressed Sparse Rows
-  representation** (see ``scipy.sparse.csr_matrix``).
-  To avoid unnecessary memory copies, it is recommended to choose the CSR
-  representation upstream.
-
 .. topic:: Scaling target variables in regression
 
   :func:`scale` and :class:`StandardScaler` work out-of-the-box with 1d arrays.
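
The sparsity-preservation claim in the new docs section is easy to verify. A minimal sketch, not part of the diff, assuming a scikit-learn build that includes this merge:

    # MaxAbsScaler scales a sparse matrix column-wise without densifying it;
    # the zero entries (and hence the sparsity pattern) survive.
    import scipy.sparse as sp
    from sklearn.preprocessing import MaxAbsScaler

    X = sp.csr_matrix([[1., -1.,  0.],
                       [2.,  0.,  0.],
                       [0.,  4., -2.]])
    scaler = MaxAbsScaler()
    X_scaled = scaler.fit_transform(X)
    print(scaler.scale_)            # per-feature max absolute values: [ 2.  4.  2.]
    print(X_scaled.nnz == X.nnz)    # True: no zeros created or destroyed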

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,11 @@ New features
      alternative to :class:`preprocessing.StandardScaler` for feature-wise
      centering and range normalization that is robust to outliers. By `Thomas Unterthiner`_.
 
+   - The new class :class:`preprocessing.MaxAbsScaler` provides an
+     alternative to :class:`preprocessing.MinMaxScaler` for feature-wise
+     range normalization when the data is already centered or sparse.
+     By `Thomas Unterthiner`_.
+
 Enhancements
 ............

sklearn/preprocessing/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,7 @@
 from .data import Binarizer
 from .data import KernelCenterer
 from .data import MinMaxScaler
+from .data import MaxAbsScaler
 from .data import Normalizer
 from .data import RobustScaler
 from .data import StandardScaler
@@ -14,6 +15,7 @@
 from .data import normalize
 from .data import scale
 from .data import robust_scale
+from .data import maxabs_scale
 from .data import OneHotEncoder
 
 from .data import PolynomialFeatures
@@ -33,6 +35,7 @@
     'LabelEncoder',
     'MultiLabelBinarizer',
     'MinMaxScaler',
+    'MaxAbsScaler',
     'Normalizer',
     'OneHotEncoder',
     'RobustScaler',
@@ -43,5 +46,6 @@
     'normalize',
     'scale',
     'robust_scale',
+    'maxabs_scale',
     'label_binarize',
 ]
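
A quick sanity check of the newly exported names; illustrative only, assuming the merged package is installed:

    # Both the estimator and the convenience function are importable from the
    # public sklearn.preprocessing namespace added above.
    import numpy as np
    from sklearn.preprocessing import MaxAbsScaler, maxabs_scale

    X = np.array([[1., -2.], [3., 4.]])
    print(maxabs_scale(X))                   # columns divided by [3., 4.]
    print(MaxAbsScaler().fit_transform(X))   # same result via the estimator API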

sklearn/preprocessing/data.py

Lines changed: 138 additions & 33 deletions
@@ -32,6 +32,7 @@
     'Binarizer',
     'KernelCenterer',
     'MinMaxScaler',
+    'MaxAbsScaler',
     'Normalizer',
     'OneHotEncoder',
     'RobustScaler',
@@ -41,6 +42,7 @@
     'normalize',
     'scale',
     'robust_scale',
+    'maxabs_scale',
 ]
 
 
@@ -59,16 +61,28 @@ def _mean_and_std(X, axis=0, with_mean=True, with_std=True):
 
     if with_std:
         std_ = Xr.std(axis=0)
-        if isinstance(std_, np.ndarray):
-            std_[std_ == 0.] = 1.0
-        elif std_ == 0.:
-            std_ = 1.
+        std_ = _handle_zeros_in_scale(std_)
     else:
         std_ = None
 
     return mean_, std_
 
 
+def _handle_zeros_in_scale(scale):
+    ''' Makes sure that whenever scale is zero, we handle it correctly.
+
+    This happens in most scalers when we have constant features.'''
+
+    # if we are fitting on 1D arrays, scale might be a scalar
+    if np.isscalar(scale):
+        if scale == 0:
+            scale = 1.
+    elif isinstance(scale, np.ndarray):
+        scale[scale == 0.0] = 1.0
+        scale[~np.isfinite(scale)] = 1.0
+    return scale
+
+
 def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
     """Standardize a dataset along any axis
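
To make the helper's contract concrete, here is a standalone re-implementation for illustration; the real one lives in sklearn/preprocessing/data.py and mutates arrays in place:

    import numpy as np

    def handle_zeros_in_scale(scale):
        # A constant feature has scale 0; dividing by it would yield inf/nan,
        # so the scale is forced to 1 and the feature passes through unchanged.
        if np.isscalar(scale):
            return 1. if scale == 0 else scale
        scale = scale.copy()              # unlike the sklearn helper, avoid mutation
        scale[scale == 0.0] = 1.0
        scale[~np.isfinite(scale)] = 1.0
        return scale

    X = np.array([[1., 5.], [2., 5.], [3., 5.]])     # second feature is constant
    print(X / handle_zeros_in_scale(X.std(axis=0)))  # no nan in the second column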
@@ -134,7 +148,7 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
         if copy:
             X = X.copy()
         _, var = mean_variance_axis(X, axis=0)
-        var[var == 0.0] = 1.0
+        var = _handle_zeros_in_scale(var)
         inplace_column_scale(X, 1 / np.sqrt(var))
     else:
         X = np.asarray(X)
@@ -237,11 +251,7 @@ def fit(self, X, y=None):
                              " than maximum. Got %s." % str(feature_range))
         data_min = np.min(X, axis=0)
         data_range = np.max(X, axis=0) - data_min
-        # Do not scale constant features
-        if isinstance(data_range, np.ndarray):
-            data_range[data_range == 0.0] = 1.0
-        elif data_range == 0.:
-            data_range = 1.
+        data_range = _handle_zeros_in_scale(data_range)
         self.scale_ = (feature_range[1] - feature_range[0]) / data_range
         self.min_ = feature_range[0] - data_min * self.scale_
         self.data_range = data_range
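
Behaviorally this refactor is a no-op; a constant feature still maps cleanly instead of dividing by zero. A sketch, again assuming a build with this merge:

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    X = np.array([[1., 7.], [2., 7.], [3., 7.]])  # second feature is constant
    print(MinMaxScaler().fit_transform(X))        # constant column maps to 0, no nan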
@@ -366,7 +376,7 @@ def fit(self, X, y=None):
             if self.with_std:
                 var = mean_variance_axis(X, axis=0)[1]
                 self.std_ = np.sqrt(var)
-                self.std_[var == 0.0] = 1.0
+                self.std_ = _handle_zeros_in_scale(self.std_)
             else:
                 self.std_ = None
             return self
@@ -437,6 +447,119 @@ def inverse_transform(self, X, copy=None):
         return X
 
 
+class MaxAbsScaler(BaseEstimator, TransformerMixin):
+    """Scale each feature by its maximum absolute value.
+
+    This estimator scales each feature individually such
+    that the maximal absolute value of each feature in the
+    training set will be 1.0. It does not shift/center the data, and
+    thus does not destroy any sparsity.
+
+    This scaler can also be applied to sparse CSR or CSC matrices.
+
+    Parameters
+    ----------
+    copy : boolean, optional, default is True
+        Set to False to perform inplace scaling and avoid a copy (if the input
+        is already a numpy array).
+
+    Attributes
+    ----------
+    scale_ : ndarray, shape (n_features,)
+        Per feature relative scaling of the data.
+    """
+
+    def __init__(self, copy=True):
+        self.copy = copy
+
+    def fit(self, X, y=None):
+        """Compute the maximum absolute value to be used for later scaling.
+
+        Parameters
+        ----------
+        X : array-like, shape [n_samples, n_features]
+            The data used to compute the per-feature maximum absolute value
+            used for later scaling along the features axis.
+        """
+        X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
+                        ensure_2d=False, estimator=self, dtype=FLOAT_DTYPES)
+        if sparse.issparse(X):
+            mins, maxs = min_max_axis(X, axis=0)
+            scales = np.maximum(np.abs(mins), np.abs(maxs))
+        else:
+            scales = np.abs(X).max(axis=0)
+            scales = np.array(scales)
+            scales = scales.reshape(-1)
+        self.scale_ = _handle_zeros_in_scale(scales)
+        return self
+
+    def transform(self, X, y=None):
+        """Scale the data
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix.
+            The data that should be scaled.
+        """
+        check_is_fitted(self, 'scale_')
+        X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
+                        ensure_2d=False, estimator=self, dtype=FLOAT_DTYPES)
+        if sparse.issparse(X):
+            if X.shape[0] == 1:
+                inplace_row_scale(X, 1.0 / self.scale_)
+            else:
+                inplace_column_scale(X, 1.0 / self.scale_)
+        else:
+            X /= self.scale_
+        return X
+
+    def inverse_transform(self, X):
+        """Scale back the data to the original representation
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix.
+            The data that should be transformed back.
+        """
+        check_is_fitted(self, 'scale_')
+        X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
+                        ensure_2d=False, estimator=self, dtype=FLOAT_DTYPES)
+        if sparse.issparse(X):
+            if X.shape[0] == 1:
+                inplace_row_scale(X, self.scale_)
+            else:
+                inplace_column_scale(X, self.scale_)
+        else:
+            X *= self.scale_
+        return X
+
+
+def maxabs_scale(X, axis=0, copy=True):
+    """Scale each feature to the [-1, 1] range without breaking the sparsity.
+
+    This function scales each feature individually such
+    that the maximal absolute value of each feature in the
+    training set will be 1.0.
+
+    It can also be applied to sparse CSR or CSC matrices.
+
+    Parameters
+    ----------
+    X : array-like or sparse matrix.
+        The data that should be scaled.
+
+    axis : int (0 by default)
+        axis used to scale along. If 0, independently scale each feature,
+        otherwise (if 1) scale each sample.
+
+    copy : boolean, optional, default is True
+        Set to False to perform inplace scaling and avoid a copy (if the input
+        is already a numpy array).
+    """
+    s = MaxAbsScaler(copy=copy)
+    if axis == 0:
+        return s.fit_transform(X)
+    else:
+        return s.fit_transform(X.T).T
+
+
 class RobustScaler(BaseEstimator, TransformerMixin):
     """Scale features using statistics that are robust to outliers.
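
A quick sketch of the axis semantics of the new convenience function, for illustration:

    import numpy as np
    from sklearn.preprocessing import maxabs_scale

    X = np.array([[1., -2.,  4.],
                  [2.,  0., -8.]])
    print(maxabs_scale(X, axis=0))  # each column divided by [2., 2., 8.]
    print(maxabs_scale(X, axis=1))  # each row divided by [4., 8.] via the X.T trick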
@@ -507,28 +630,15 @@ def __init__(self, with_centering=True, with_scaling=True, copy=True):
 
     def _check_array(self, X, copy):
         """Makes sure centering is not enabled for sparse matrices."""
-        X = check_array(X, accept_sparse=('csr', 'csc'), dtype=np.float,
-                        copy=copy, ensure_2d=False)
+        X = check_array(X, accept_sparse=('csr', 'csc'), copy=copy,
+                        ensure_2d=False, estimator=self, dtype=FLOAT_DTYPES)
         if sparse.issparse(X):
             if self.with_centering:
                 raise ValueError(
                     "Cannot center sparse matrices: use `with_centering=False`"
                     " instead. See docstring for motivation and alternatives.")
         return X
 
-    def _handle_zeros_in_scale(self, scale):
-        ''' Makes sure that whenever scale is zero, we handle it correctly.
-
-        This happens in most scalers when we have constant features.'''
-        # if we are fitting on 1D arrays, scale might be a scalar
-        if np.isscalar(scale):
-            if scale == 0:
-                scale = 1.
-        elif isinstance(scale, np.ndarray):
-            scale[scale == 0.0] = 1.0
-            scale[~np.isfinite(scale)] = 1.0
-        return scale
-
     def fit(self, X, y=None):
         """Compute the median and quantiles to be used for scaling.
@@ -548,12 +658,7 @@ def fit(self, X, y=None):
         if self.with_scaling:
             q = np.percentile(X, (25, 75), axis=0)
             self.scale_ = (q[1] - q[0])
-            if np.isscalar(self.scale_):
-                if self.scale_ == 0:
-                    self.scale_ = 1.
-            else:
-                self.scale_[self.scale_ == 0.0] = 1.0
-                self.scale_[~np.isfinite(self.scale_)] = 1.0
+            self.scale_ = _handle_zeros_in_scale(self.scale_)
         return self
 
     def transform(self, X, y=None):
@@ -860,7 +965,7 @@ def normalize(X, norm='l2', axis=1, copy=True):
             norms = row_norms(X)
         elif norm == 'max':
             norms = np.max(X, axis=1)
-        norms[norms == 0.0] = 1.0
+        norms = _handle_zeros_in_scale(norms)
         X /= norms[:, np.newaxis]
 
     if axis == 0:
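
The ``'max'`` norm path now routes through the same helper, so an all-zero row keeps a divisor of 1 instead of producing nan. A sketch:

    import numpy as np
    from sklearn.preprocessing import normalize

    X = np.array([[1., 2., 4.],
                  [0., 0., 0.]])
    print(normalize(X, norm='max'))  # [[0.25, 0.5, 1.], [0., 0., 0.]]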
