@@ -27,7 +27,7 @@ class _BaseEncoder(TransformerMixin, BaseEstimator):
2727
2828 """
2929
30- def _check_X (self , X ):
30+ def _check_X (self , X , force_all_finite = True ):
3131 """
3232 Perform custom check_array:
3333 - convert list of strings to object dtype
@@ -41,17 +41,19 @@ def _check_X(self, X):
4141 """
4242 if not (hasattr (X , 'iloc' ) and getattr (X , 'ndim' , 0 ) == 2 ):
4343 # if not a dataframe, do normal check_array validation
44- X_temp = check_array (X , dtype = None )
44+ X_temp = check_array (X , dtype = None ,
45+ force_all_finite = force_all_finite )
4546 if (not hasattr (X , 'dtype' )
4647 and np .issubdtype (X_temp .dtype , np .str_ )):
47- X = check_array (X , dtype = object )
48+ X = check_array (X , dtype = object ,
49+ force_all_finite = force_all_finite )
4850 else :
4951 X = X_temp
5052 needs_validation = False
5153 else :
5254 # pandas dataframe, do validation later column by column, in order
5355 # to keep the dtype information to be used in the encoder.
54- needs_validation = True
56+ needs_validation = force_all_finite
5557
5658 n_samples , n_features = X .shape
5759 X_columns = []
@@ -71,8 +73,9 @@ def _get_feature(self, X, feature_idx):
7173 # numpy arrays, sparse arrays
7274 return X [:, feature_idx ]
7375
74- def _fit (self , X , handle_unknown = 'error' ):
75- X_list , n_samples , n_features = self ._check_X (X )
76+ def _fit (self , X , handle_unknown = 'error' , force_all_finite = True ):
77+ X_list , n_samples , n_features = self ._check_X (
78+ X , force_all_finite = force_all_finite )
7679
7780 if self .categories != 'auto' :
7881 if len (self .categories ) != n_features :
@@ -88,9 +91,16 @@ def _fit(self, X, handle_unknown='error'):
8891 else :
8992 cats = np .array (self .categories [i ], dtype = Xi .dtype )
9093 if Xi .dtype != object :
91- if not np .all (np .sort (cats ) == cats ):
92- raise ValueError ("Unsorted categories are not "
93- "supported for numerical categories" )
94+ sorted_cats = np .sort (cats )
95+ error_msg = ("Unsorted categories are not "
96+ "supported for numerical categories" )
97+ # if there are nans, nan should be the last element
98+ stop_idx = - 1 if np .isnan (sorted_cats [- 1 ]) else None
99+ if (np .any (sorted_cats [:stop_idx ] != cats [:stop_idx ]) or
100+ (np .isnan (sorted_cats [- 1 ]) and
101+ not np .isnan (cats [- 1 ]))):
102+ raise ValueError (error_msg )
103+
94104 if handle_unknown == 'error' :
95105 diff = _check_unknown (Xi , cats )
96106 if diff :
@@ -99,8 +109,9 @@ def _fit(self, X, handle_unknown='error'):
99109 raise ValueError (msg )
100110 self .categories_ .append (cats )
101111
102- def _transform (self , X , handle_unknown = 'error' ):
103- X_list , n_samples , n_features = self ._check_X (X )
112+ def _transform (self , X , handle_unknown = 'error' , force_all_finite = True ):
113+ X_list , n_samples , n_features = self ._check_X (
114+ X , force_all_finite = force_all_finite )
104115
105116 X_int = np .zeros ((n_samples , n_features ), dtype = int )
106117 X_mask = np .ones ((n_samples , n_features ), dtype = bool )
@@ -355,8 +366,26 @@ def _compute_drop_idx(self):
355366 "of features ({}), got {}" )
356367 raise ValueError (msg .format (len (self .categories_ ),
357368 len (self .drop )))
358- missing_drops = [(i , val ) for i , val in enumerate (self .drop )
359- if val not in self .categories_ [i ]]
369+ missing_drops = []
370+ drop_indices = []
371+ for col_idx , (val , cat_list ) in enumerate (zip (self .drop ,
372+ self .categories_ )):
373+ if not is_scalar_nan (val ):
374+ drop_idx = np .where (cat_list == val )[0 ]
375+ if drop_idx .size : # found drop idx
376+ drop_indices .append (drop_idx [0 ])
377+ else :
378+ missing_drops .append ((col_idx , val ))
379+ continue
380+
381+ # val is nan, find nan in categories manually
382+ for cat_idx , cat in enumerate (cat_list ):
383+ if is_scalar_nan (cat ):
384+ drop_indices .append (cat_idx )
385+ break
386+ else : # loop did not break thus drop is missing
387+ missing_drops .append ((col_idx , val ))
388+
360389 if any (missing_drops ):
361390 msg = ("The following categories were supposed to be "
362391 "dropped, but were not found in the training "
@@ -365,10 +394,7 @@ def _compute_drop_idx(self):
365394 ["Category: {}, Feature: {}" .format (c , v )
366395 for c , v in missing_drops ])))
367396 raise ValueError (msg )
368- return np .array ([np .where (cat_list == val )[0 ][0 ]
369- for (val , cat_list ) in
370- zip (self .drop , self .categories_ )],
371- dtype = object )
397+ return np .array (drop_indices , dtype = object )
372398
373399 def fit (self , X , y = None ):
374400 """
@@ -388,7 +414,8 @@ def fit(self, X, y=None):
388414 self
389415 """
390416 self ._validate_keywords ()
391- self ._fit (X , handle_unknown = self .handle_unknown )
417+ self ._fit (X , handle_unknown = self .handle_unknown ,
418+ force_all_finite = 'allow-nan' )
392419 self .drop_idx_ = self ._compute_drop_idx ()
393420 return self
394421
@@ -431,7 +458,8 @@ def transform(self, X):
431458 """
432459 check_is_fitted (self )
433460 # validation of X happens in _check_X called by _transform
434- X_int , X_mask = self ._transform (X , handle_unknown = self .handle_unknown )
461+ X_int , X_mask = self ._transform (X , handle_unknown = self .handle_unknown ,
462+ force_all_finite = 'allow-nan' )
435463
436464 n_samples , n_features = X_int .shape
437465
0 commit comments