Skip to content

Commit 22f718d

Browse files
committed
ENH Added RobustScaler
1 parent d9ecf46 commit 22f718d

File tree

3 files changed

+357
-0
lines changed

3 files changed

+357
-0
lines changed

doc/modules/preprocessing.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ full formula is::
145145

146146
X_scaled = X_std / (max - min) + min
147147

148+
149+
Scaling data with outliers
150+
--------------------------
151+
If your data contains many outliers, scaling using the mean and variance
152+
of the data sometimes does not work very well. In these cases, you can use
153+
:func:`robust_scale` and :class:`RobustScaler` as drop-in replacements
154+
instead, which use more robust estimates for the center and range of your
155+
data.
156+
157+
148158
.. topic:: References:
149159

150160
Further discussion on the importance of centering and scaling data is

sklearn/preprocessing/data.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,259 @@ def inverse_transform(self, X, copy=None):
397397
return X
398398

399399

400+
class RobustScaler(BaseEstimator, TransformerMixin):
    """Standardize features by removing the median and scaling to IQR.

    Centering and scaling happen independently on each feature by computing
    the relevant statistics on the samples in the training set. Median and
    interquartile range are then stored to be used on later data using the
    `transform` method.

    Standardization of a dataset is a common requirement for many
    machine learning estimators. Typically this is done by removing the mean
    and scaling to unit variance. However, outliers can often influence the
    sample mean / variance in a negative way. In such cases, the median and
    the interquartile range often give better results.

    Parameters
    ----------
    interquartile_scale : float or string in ["normal" (default), ]
        The interquartile range is divided by this factor. If
        `interquartile_scale` is "normal", the data is scaled so it
        approximately reaches unit variance. This conversion assumes Gaussian
        input data and will need a large number of samples.

    with_centering : boolean, True by default
        If True, center the data before scaling.
        This does not work (and will raise an exception) when attempted on
        sparse matrices, because centering them entails building a dense
        matrix which in common use cases is likely to be too large to fit in
        memory.

    with_scaling : boolean, True by default
        If True, scale the data to interquartile range.

    copy : boolean, optional, default is True
        If False, try to avoid a copy and do inplace scaling instead.
        This is not guaranteed to always work inplace; e.g. if the data is
        not a NumPy array or scipy.sparse CSR matrix, a copy may still be
        returned.

    Attributes
    ----------
    `center_` : array of floats
        The median value for each feature in the training set, or None
        when `with_centering` is False.

    `scale_` : array of floats
        The interquartile range of each feature in the training set divided
        by the `interquartile_scale` factor, or None when `with_scaling` is
        False. Zero and non-finite entries are replaced by 1 so that
        constant features are left unscaled.

    See also
    --------
    :class:`sklearn.preprocessing.StandardScaler` to perform centering
    and scaling using mean and variance.

    :class:`sklearn.decomposition.RandomizedPCA` with `whiten=True`
    to further remove the linear correlation across features.
    """

    def __init__(self, interquartile_scale="normal", with_centering=True,
                 with_scaling=True, copy=True):
        self.interquartile_scale = interquartile_scale
        self.with_centering = with_centering
        self.with_scaling = with_scaling
        self.copy = copy

    def _check_array(self, X, copy):
        """Validate X; make sure centering is not enabled for sparse input.

        Raises
        ------
        ValueError
            If X is sparse while `with_centering` is True.
        """
        X = check_array(X, accept_sparse=('csr', 'csc'),
                        copy=copy, ensure_2d=False)
        # `np.float` was only an alias of the builtin float (i.e. float64)
        # and is gone from recent NumPy releases; spell the dtype out.
        if warn_if_not_float(X, estimator=self):
            X = X.astype(np.float64)
        if sparse.issparse(X) and self.with_centering:
            raise ValueError(
                "Cannot center sparse matrices: use `with_centering=False`"
                " instead. See docstring for motivation and alternatives.")
        return X

    def _handle_zeros_in_scale(self, scale):
        """Make sure that whenever scale is zero, we handle it correctly.

        This happens in most scalers when we have constant features. Zero
        (and, for arrays, non-finite) scales are replaced by 1 so that
        dividing by the scale is a no-op for those features. Arrays are
        modified in place.
        """
        # if we are fitting on 1D arrays, scale might be a scalar
        if np.isscalar(scale):
            if scale == 0:
                scale = 1.
        elif isinstance(scale, np.ndarray):
            scale[scale == 0.0] = 1.0
            # `~` is the boolean negation; unary minus on a boolean mask is
            # rejected by current NumPy.
            scale[~np.isfinite(scale)] = 1.0
        return scale

    def fit(self, X, y=None, copy=None):
        """Compute the median and interquartile range used for later scaling.

        Parameters
        ----------
        X : array-like with shape [n_samples, n_features]
            The data used to compute the median and the interquartile range
            used for later scaling along the features axis. Sparse input is
            rejected.

        y : ignored

        copy : boolean or None
            Overrides the `copy` constructor parameter when not None.

        Returns
        -------
        self

        Raises
        ------
        TypeError
            If X is a sparse matrix.
        ValueError
            If `interquartile_scale` is neither a real number nor "normal".
        """
        if sparse.issparse(X):
            raise TypeError("RobustScaler cannot be fitted on sparse inputs")

        if np.isreal(self.interquartile_scale):
            iqr_scale = self.interquartile_scale
        elif self.interquartile_scale == "normal":
            # IQR of a standard normal distribution, 2 * 0.67449: dividing
            # the IQR by it estimates the standard deviation of Gaussian data
            iqr_scale = 1.34898
        else:
            raise ValueError("Unknown interquartile_scale value.")

        if copy is None:
            copy = self.copy

        self.center_ = None
        self.scale_ = None
        X = self._check_array(X, copy)
        if self.with_centering:
            self.center_ = np.median(X, axis=0)

        if self.with_scaling:
            q = np.percentile(X, (25, 75), axis=0)
            self.scale_ = self._handle_zeros_in_scale(
                (q[1] - q[0]) / iqr_scale)
        return self

    def transform(self, X, y=None, copy=None):
        """Center and scale the data using the fitted statistics.

        Parameters
        ----------
        X : array-like or CSR matrix.
            The data to center and scale along the features axis.

        y : ignored

        copy : boolean or None
            Overrides the `copy` constructor parameter when not None.

        Returns
        -------
        X : the centered and/or scaled data.
        """
        if self.with_centering:
            check_is_fitted(self, 'center_')
        if self.with_scaling:
            check_is_fitted(self, 'scale_')
        if copy is None:
            copy = self.copy
        # `_check_array` forwards `copy` to `check_array`, which is expected
        # to return a fresh array when copy is True, so no second copy is
        # made before the in-place operations below.
        X = self._check_array(X, copy)
        if sparse.issparse(X):
            # sparse input is never centered (`_check_array` rejects it)
            if self.with_scaling:
                if X.shape[0] == 1:
                    inplace_row_scale(X, 1.0 / self.scale_)
                else:
                    inplace_column_scale(X, 1.0 / self.scale_)
        else:
            if self.with_centering:
                X -= self.center_
            if self.with_scaling:
                X /= self.scale_
        return X

    def inverse_transform(self, X, copy=None):
        """Scale back the data to the original representation

        Parameters
        ----------
        X : array-like or CSR matrix.
            The transformed data to map back along the features axis.

        copy : boolean or None
            Overrides the `copy` constructor parameter when not None.

        Returns
        -------
        X : the data with scaling and centering undone.
        """
        if self.with_centering:
            check_is_fitted(self, 'center_')
        if self.with_scaling:
            check_is_fitted(self, 'scale_')
        if copy is None:
            copy = self.copy
        # see `transform`: the copy, if requested, is made by `check_array`
        X = self._check_array(X, copy)
        if sparse.issparse(X):
            if self.with_scaling:
                if X.shape[0] == 1:
                    inplace_row_scale(X, self.scale_)
                else:
                    inplace_column_scale(X, self.scale_)
        else:
            # undo the steps of `transform` in reverse order
            if self.with_scaling:
                X *= self.scale_
            if self.with_centering:
                X += self.center_
        return X
588+
589+
590+
def robust_scale(X, interquartile_scale="normal", axis=0, with_centering=True,
                 with_scaling=True, copy=True):
    """Standardize a dataset along any axis

    Center to the median and component wise scale
    according to the interquartile range.

    Parameters
    ----------
    X : array-like.
        The data to center and scale.

    interquartile_scale : float or string in ["normal" (default), ]
        The interquartile range is divided by this factor. If
        `interquartile_scale` is "normal", the data is scaled so it
        approximately reaches unit variance. This conversion assumes Gaussian
        input data and will need a large number of samples.

    axis : int (0 by default)
        axis used to compute the medians and IQR along. If 0,
        independently scale each feature, otherwise (if 1) scale
        each sample.

    with_centering : boolean, True by default
        If True, center the data before scaling.

    with_scaling : boolean, True by default
        If True, scale the data to the interquartile range (divided by
        `interquartile_scale`).

    copy : boolean, optional, default is True
        set to False to try to avoid a copy and scale in place (only
        possible if the input is already a numpy array of floats).

    Returns
    -------
    The centered and/or scaled data, oriented like the input.

    Notes
    -----
    The statistics are computed by :class:`RobustScaler`, whose `fit`
    raises a TypeError for scipy.sparse input. Call `X.toarray()` first if
    the materialized dense array is expected to fit in memory.

    See also
    --------
    :class:`sklearn.preprocessing.RobustScaler` to perform centering and
    scaling using the ``Transformer`` API (e.g. as part of a preprocessing
    :class:`sklearn.pipeline.Pipeline`)
    """
    s = RobustScaler(interquartile_scale=interquartile_scale,
                     with_centering=with_centering, with_scaling=with_scaling,
                     copy=copy)
    if axis == 0:
        return s.fit_transform(X)

    # Plain sequences (e.g. nested lists) have no `.T` attribute; convert
    # them first. ndarrays and their subclasses pass through untouched, and
    # sparse matrices are left alone so the scaler can reject them itself.
    if not sparse.issparse(X):
        X = np.asanyarray(X)
    return s.fit_transform(X.T).T
651+
652+
400653
class PolynomialFeatures(BaseEstimator, TransformerMixin):
401654
"""Generate polynomial and interaction features.
402655

sklearn/preprocessing/tests/test_data.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from sklearn.preprocessing.data import StandardScaler
2626
from sklearn.preprocessing.data import scale
2727
from sklearn.preprocessing.data import MinMaxScaler
28+
from sklearn.preprocessing.data import RobustScaler
29+
from sklearn.preprocessing.data import robust_scale
2830
from sklearn.preprocessing.data import add_dummy_feature
2931
from sklearn.preprocessing.data import PolynomialFeatures
3032

@@ -455,6 +457,96 @@ def test_scale_function_without_centering():
455457
assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0))
456458

457459

460+
def test_robust_scaler_2d_arrays():
    """Robust scaling of a 2d array along the first axis."""
    random_state = np.random.RandomState(0)
    data = random_state.randn(4, 5)
    data[:, 0] = 0.0  # constant, all-zero first feature

    fitted = RobustScaler().fit(data)
    scaled = fitted.transform(data, copy=True)

    # every column is centered on its median ...
    expected_medians = [0.0] * 5
    assert_array_almost_equal(np.median(scaled, axis=0), expected_medians)
    # ... and the constant column stays constant
    assert_array_almost_equal(scaled.std(axis=0)[0], 0)
471+
472+
473+
def test_robust_scaler_iris():
    """On iris: zero medians, unit IQR, and an exact round trip."""
    data = iris.data
    scaler = RobustScaler(interquartile_scale=1.0)

    transformed = scaler.fit_transform(data)
    assert_array_almost_equal(np.median(transformed, axis=0), 0)

    restored = scaler.inverse_transform(transformed)
    assert_array_almost_equal(data, restored)

    # with a scale factor of 1 the transformed IQR must be exactly 1
    lower, upper = np.percentile(transformed, q=(25, 75), axis=0)
    assert_array_almost_equal(upper - lower, 1)
483+
484+
485+
def test_robust_scale_axis1():
    """`robust_scale` along axis=1 centers and scales each sample."""
    data = iris.data
    transformed = robust_scale(data, interquartile_scale=1.0, axis=1)

    assert_array_almost_equal(np.median(transformed, axis=1), 0)

    lower, upper = np.percentile(transformed, q=(25, 75), axis=1)
    assert_array_almost_equal(upper - lower, 1)
492+
493+
494+
def test_robust_scaler_iqr_scale():
    """Default IQR scaling gives approximately unit std on Gaussian data."""
    random_state = np.random.RandomState(42)
    # the approximation only holds for a large number of samples
    gaussian = random_state.randn(10000, 4)

    transformed = RobustScaler().fit_transform(gaussian)

    assert_array_almost_equal(transformed.std(axis=0), 1, decimal=2)
501+
502+
503+
def test_robust_scale_iqr_errors():
    """An unknown `interquartile_scale` value must yield a ValueError."""
    random_state = np.random.RandomState(42)
    data = random_state.randn(4, 5)

    bad_scaler = RobustScaler(interquartile_scale="foo")
    assert_raises(ValueError, bad_scaler.fit, data)

    # TODO: for some reason assert_raise doesn't test this correctly,
    # so the function variant is checked by hand.
    try:
        robust_scale(data, interquartile_scale="foo")
    except ValueError:
        did_raise = True
    else:
        did_raise = False
    assert did_raise
515+
516+
517+
def test_robust_scaler_zero_variance_features():
    """Check robust scaler on toy data with zero variance features"""
    train = [[0., 1., +0.5],
             [0., 1., -0.1],
             [0., 1., +1.1]]

    scaler = RobustScaler(interquartile_scale=1.0)
    train_scaled = scaler.fit_transform(train)

    # NOTE: for such a small sample size, what we expect in the third column
    # depends HEAVILY on the method used to calculate quantiles. The values
    # here were calculated to fit the quantiles produced by np.percentile
    # using numpy 1.9. Calculating quantiles with
    # scipy.stats.mstats.scoreatquantile or scipy.stats.mstats.mquantiles
    # would yield very different results!
    train_expected = [[0., 0., +0.0],
                      [0., 0., -1.0],
                      [0., 0., +1.0]]
    assert_array_almost_equal(train_scaled, train_expected)

    # the round trip must give back the training data
    train_restored = scaler.inverse_transform(train_scaled)
    assert_array_almost_equal(train, train_restored)

    # make sure new data gets transformed correctly
    unseen = [[+0., 2., 0.5],
              [-1., 1., 0.0],
              [+0., 1., 1.5]]
    unseen_scaled = scaler.transform(unseen)
    unseen_expected = [[+0., 1., +0.],
                       [-1., 0., -0.83333],
                       [+0., 0., +1.66667]]
    assert_array_almost_equal(unseen_scaled, unseen_expected, decimal=3)
548+
549+
458550
def test_warning_scaling_integers():
459551
"""Check warning when scaling integer data"""
460552
X = np.array([[1, 2, 0],
@@ -466,6 +558,8 @@ def test_warning_scaling_integers():
466558
assert_warns_message(UserWarning, w, scale, X)
467559
assert_warns_message(UserWarning, w, StandardScaler().fit, X)
468560
assert_warns_message(UserWarning, w, MinMaxScaler().fit, X)
561+
assert_warns_message(UserWarning, w, robust_scale, X)
562+
assert_warns_message(UserWarning, w, RobustScaler().fit, X)
469563

470564

471565
def test_normalizer_l1():

0 commit comments

Comments
 (0)