@@ -269,34 +269,38 @@ def check_estimator_sparse_data(name, Estimator):
269269 rng = np .random .RandomState (0 )
270270 X = rng .rand (40 , 10 )
271271 X [X < .8 ] = 0
272- X = sparse .csr_matrix (X )
272+ X_csr = sparse .csr_matrix (X )
273273 y = (4 * rng .rand (40 )).astype (np .int )
274- # catch deprecation warnings
275- with warnings .catch_warnings ():
276- if name in ['Scaler' , 'StandardScaler' ]:
277- estimator = Estimator (with_mean = False )
278- else :
279- estimator = Estimator ()
280- set_fast_parameters (estimator )
281- # fit and predict
282- try :
283- estimator .fit (X , y )
284- if hasattr (estimator , "predict" ):
285- estimator .predict (X )
286- if hasattr (estimator , 'predict_proba' ):
287- estimator .predict_proba (X )
288- except TypeError as e :
289- if 'sparse' not in repr (e ):
274+ for sparse_format in ['csr' , 'csc' , 'dok' , 'lil' , 'coo' , 'dia' , 'bsr' ]:
275+ X = X_csr .asformat (sparse_format )
276+ # catch deprecation warnings
277+ with warnings .catch_warnings ():
278+ if name in ['Scaler' , 'StandardScaler' ]:
279+ estimator = Estimator (with_mean = False )
280+ else :
281+ estimator = Estimator ()
282+ set_fast_parameters (estimator )
283+ # fit and predict
284+ try :
285+ estimator .fit (X , y )
286+ if hasattr (estimator , "predict" ):
287+ pred = estimator .predict (X )
288+ assert_equal (pred .shape , (X .shape [0 ],))
289+ if hasattr (estimator , 'predict_proba' ):
290+ probs = estimator .predict_proba (X )
291+ assert_equal (probs .shape , (X .shape [0 ], 4 ))
292+ except TypeError as e :
293+ if 'sparse' not in repr (e ):
294+ print ("Estimator %s doesn't seem to fail gracefully on "
295+ "sparse data: error message state explicitly that "
296+ "sparse input is not supported if this is not the case."
297+ % name )
298+ raise
299+ except Exception :
290300 print ("Estimator %s doesn't seem to fail gracefully on "
291- "sparse data: error message state explicitly that "
292- "sparse input is not supported if this is not the case."
293- % name )
301+ "sparse data: it should raise a TypeError if sparse input "
302+ "is explicitly not supported." % name )
294303 raise
295- except Exception :
296- print ("Estimator %s doesn't seem to fail gracefully on "
297- "sparse data: it should raise a TypeError if sparse input "
298- "is explicitly not supported." % name )
299- raise
300304
301305
302306def check_dtype_object (name , Estimator ):
0 commit comments