@@ -126,6 +126,9 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
126126
127127 To avoid memory copy the caller should pass a CSC matrix.
128128
129+ NaNs are treated as missing values: disregarded when computing the statistics,
130+ and maintained during the data transformation.
131+
129132 For a comparison of the different scalers, transformers, and normalizers,
130133 see :ref:`examples/preprocessing/plot_all_scaling.py
131134 <sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
@@ -138,7 +141,7 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
138141 """ # noqa
139142 X = check_array (X , accept_sparse = 'csc' , copy = copy , ensure_2d = False ,
140143 warn_on_dtype = True , estimator = 'the scale function' ,
141- dtype = FLOAT_DTYPES )
144+ dtype = FLOAT_DTYPES , force_all_finite = 'allow-nan' )
142145 if sparse .issparse (X ):
143146 if with_mean :
144147 raise ValueError (
@@ -154,15 +157,15 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
154157 else :
155158 X = np .asarray (X )
156159 if with_mean :
157- mean_ = np .mean (X , axis )
160+ mean_ = np .nanmean (X , axis )
158161 if with_std :
159- scale_ = np .std (X , axis )
162+ scale_ = np .nanstd (X , axis )
160163 # Xr is a view on the original array that enables easy use of
161164 # broadcasting on the axis in which we are interested in
162165 Xr = np .rollaxis (X , axis )
163166 if with_mean :
164167 Xr -= mean_
165- mean_1 = Xr . mean ( axis = 0 )
168+ mean_1 = np . nanmean ( Xr , axis = 0 )
166169 # Verify that mean_1 is 'close to zero'. If X contains very
167170 # large values, mean_1 can also be very large, due to a lack of
168171 # precision of mean_. In this case, a pre-scaling of the
@@ -179,7 +182,7 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
179182 scale_ = _handle_zeros_in_scale (scale_ , copy = False )
180183 Xr /= scale_
181184 if with_mean :
182- mean_2 = Xr . mean ( axis = 0 )
185+ mean_2 = np . nanmean ( Xr , axis = 0 )
183186 # If mean_2 is not 'close to zero', it comes from the fact that
184187 # scale_ is very small so that mean_2 = mean_1/scale_ > 0, even
185188 # if mean_1 was close to zero. The problem is thus essentially
@@ -520,27 +523,31 @@ class StandardScaler(BaseEstimator, TransformerMixin):
520523
521524 Attributes
522525 ----------
523- scale_ : ndarray, shape (n_features,)
524- Per feature relative scaling of the data.
526+ scale_ : ndarray or None, shape (n_features,)
527+ Per feature relative scaling of the data. Equal to ``None`` when
528+ ``with_std=False``.
525529
526530 .. versionadded:: 0.17
527531 *scale_*
528532
529- mean_ : array of floats with shape [ n_features]
533+ mean_ : ndarray or None, shape ( n_features,)
530534 The mean value for each feature in the training set.
535+ Equal to ``None`` when ``with_mean=False``.
531536
532- var_ : array of floats with shape [ n_features]
537+ var_ : ndarray or None, shape ( n_features,)
533538 The variance for each feature in the training set. Used to compute
534- `scale_`
539+ `scale_`. Equal to ``None`` when ``with_std=False``.
535540
536- n_samples_seen_ : int
537- The number of samples processed by the estimator. Will be reset on
538- new calls to fit, but increments across ``partial_fit`` calls.
541+ n_samples_seen_ : int or array, shape (n_features,)
542+ The number of samples processed by the estimator for each feature.
543+ If there are no missing samples, the ``n_samples_seen`` will be an
544+ integer, otherwise it will be an array.
545+ Will be reset on new calls to fit, but increments across
546+ ``partial_fit`` calls.
539547
540548 Examples
541549 --------
542550 >>> from sklearn.preprocessing import StandardScaler
543- >>>
544551 >>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
545552 >>> scaler = StandardScaler()
546553 >>> print(scaler.fit(data))
@@ -564,6 +571,9 @@ class StandardScaler(BaseEstimator, TransformerMixin):
564571
565572 Notes
566573 -----
574+ NaNs are treated as missing values: disregarded in fit, and maintained in
575+ transform.
576+
567577 For a comparison of the different scalers, transformers, and normalizers,
568578 see :ref:`examples/preprocessing/plot_all_scaling.py
569579 <sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
@@ -626,22 +636,41 @@ def partial_fit(self, X, y=None):
626636 Ignored
627637 """
628638 X = check_array (X , accept_sparse = ('csr' , 'csc' ), copy = self .copy ,
629- warn_on_dtype = True , estimator = self , dtype = FLOAT_DTYPES )
639+ warn_on_dtype = True , estimator = self , dtype = FLOAT_DTYPES ,
640+ force_all_finite = 'allow-nan' )
630641
631642 # Even in the case of `with_mean=False`, we update the mean anyway
632643 # This is needed for the incremental computation of the var
633644 # See incr_mean_variance_axis and _incremental_mean_variance_axis
634645
646+ # if n_samples_seen_ is an integer (i.e. no missing values), we need to
647+ # transform it to a NumPy array of shape (n_features,) required by
648+ # incr_mean_variance_axis and _incremental_mean_and_var
649+ if (hasattr (self , 'n_samples_seen_' ) and
650+ isinstance (self .n_samples_seen_ , (int , np .integer ))):
651+ self .n_samples_seen_ = np .repeat (self .n_samples_seen_ ,
652+ X .shape [1 ]).astype (np .int64 )
653+
635654 if sparse .issparse (X ):
636655 if self .with_mean :
637656 raise ValueError (
638657 "Cannot center sparse matrices: pass `with_mean=False` "
639658 "instead. See docstring for motivation and alternatives." )
659+
660+ sparse_constructor = (sparse .csr_matrix
661+ if X .format == 'csr' else sparse .csc_matrix )
662+ counts_nan = sparse_constructor (
663+ (np .isnan (X .data ), X .indices , X .indptr ),
664+ shape = X .shape ).sum (axis = 0 ).A .ravel ()
665+
666+ if not hasattr (self , 'n_samples_seen_' ):
667+ self .n_samples_seen_ = (X .shape [0 ] -
668+ counts_nan ).astype (np .int64 )
669+
640670 if self .with_std :
641671 # First pass
642- if not hasattr (self , 'n_samples_seen_ ' ):
672+ if not hasattr (self , 'scale_ ' ):
643673 self .mean_ , self .var_ = mean_variance_axis (X , axis = 0 )
644- self .n_samples_seen_ = X .shape [0 ]
645674 # Next passes
646675 else :
647676 self .mean_ , self .var_ , self .n_samples_seen_ = \
@@ -652,15 +681,15 @@ def partial_fit(self, X, y=None):
652681 else :
653682 self .mean_ = None
654683 self .var_ = None
655- if not hasattr (self , 'n_samples_seen_' ):
656- self .n_samples_seen_ = X .shape [0 ]
657- else :
658- self .n_samples_seen_ += X .shape [0 ]
684+ if hasattr (self , 'scale_' ):
685+ self .n_samples_seen_ += X .shape [0 ] - counts_nan
659686 else :
660- # First pass
661687 if not hasattr (self , 'n_samples_seen_' ):
688+ self .n_samples_seen_ = np .zeros (X .shape [1 ], dtype = np .int64 )
689+
690+ # First pass
691+ if not hasattr (self , 'scale_' ):
662692 self .mean_ = .0
663- self .n_samples_seen_ = 0
664693 if self .with_std :
665694 self .var_ = .0
666695 else :
@@ -669,12 +698,18 @@ def partial_fit(self, X, y=None):
669698 if not self .with_mean and not self .with_std :
670699 self .mean_ = None
671700 self .var_ = None
672- self .n_samples_seen_ += X .shape [0 ]
701+ self .n_samples_seen_ += X .shape [0 ] - np . isnan ( X ). sum ( axis = 0 )
673702 else :
674703 self .mean_ , self .var_ , self .n_samples_seen_ = \
675704 _incremental_mean_and_var (X , self .mean_ , self .var_ ,
676705 self .n_samples_seen_ )
677706
707+ # for backward-compatibility, reduce n_samples_seen_ to an integer
708+ # if the number of samples is the same for each feature (i.e. no
709+ # missing values)
710+ if np .ptp (self .n_samples_seen_ ) == 0 :
711+ self .n_samples_seen_ = self .n_samples_seen_ [0 ]
712+
678713 if self .with_std :
679714 self .scale_ = _handle_zeros_in_scale (np .sqrt (self .var_ ))
680715 else :
@@ -704,7 +739,8 @@ def transform(self, X, y='deprecated', copy=None):
704739
705740 copy = copy if copy is not None else self .copy
706741 X = check_array (X , accept_sparse = 'csr' , copy = copy , warn_on_dtype = True ,
707- estimator = self , dtype = FLOAT_DTYPES )
742+ estimator = self , dtype = FLOAT_DTYPES ,
743+ force_all_finite = 'allow-nan' )
708744
709745 if sparse .issparse (X ):
710746 if self .with_mean :
0 commit comments