4848
4949BOSTON = None
5050CROSS_DECOMPOSITION = ['PLSCanonical' , 'PLSRegression' , 'CCA' , 'PLSSVD' ]
51+ MULTI_OUTPUT = ['CCA' , 'DecisionTreeRegressor' , 'ElasticNet' ,
52+ 'ExtraTreeRegressor' , 'ExtraTreesRegressor' , 'GaussianProcess' ,
53+ 'KNeighborsRegressor' , 'KernelRidge' , 'Lars' , 'Lasso' ,
54+ 'LassoLars' , 'LinearRegression' , 'MultiTaskElasticNet' ,
55+ 'MultiTaskElasticNetCV' , 'MultiTaskLasso' , 'MultiTaskLassoCV' ,
56+ 'OrthogonalMatchingPursuit' , 'PLSCanonical' , 'PLSRegression' ,
57+ 'RANSACRegressor' , 'RadiusNeighborsRegressor' ,
58+ 'RandomForestRegressor' , 'Ridge' , 'RidgeCV' ]
5159
5260
5361def _yield_non_meta_checks (name , Estimator ):
@@ -100,8 +108,7 @@ def _yield_classifier_checks(name, Classifier):
100108 # We don't raise a warning in these classifiers, as
101109 # the column y interface is used by the forests.
102110
103- # test if classifiers can cope with y.shape = (n_samples, 1)
104- yield check_classifiers_input_shapes
111+ yield check_supervised_y_2d
105112 # test if NotFittedError is raised
106113 yield check_estimators_unfitted
107114 if 'class_weight' in Classifier ().get_params ().keys ():
@@ -116,6 +123,7 @@ def _yield_regressor_checks(name, Regressor):
116123 yield check_regressor_data_not_an_array
117124 yield check_estimators_partial_fit_n_features
118125 yield check_regressors_no_decision_function
126+ yield check_supervised_y_2d
119127 if name != 'CCA' :
120128 # check that the regressor handles int input
121129 yield check_regressors_int
@@ -831,31 +839,36 @@ def check_estimators_unfitted(name, Estimator):
831839 est .predict_log_proba , X )
832840
833841
834- def check_classifiers_input_shapes (name , Classifier ):
835- iris = load_iris ()
836- X , y = iris .data , iris .target
837- X , y = shuffle (X , y , random_state = 1 )
838- X = StandardScaler ().fit_transform (X )
842+ def check_supervised_y_2d (name , Estimator ):
843+ if "MultiTask" in name :
844+ # These only work on 2d, so this test makes no sense
845+ return
846+ rnd = np .random .RandomState (0 )
847+ X = rnd .uniform (size = (10 , 3 ))
848+ y = np .arange (10 ) % 3
839849 # catch deprecation warnings
840850 with warnings .catch_warnings (record = True ):
841- classifier = Classifier ()
842- set_fast_parameters (classifier )
843- set_random_state (classifier )
851+ estimator = Estimator ()
852+ set_fast_parameters (estimator )
853+ set_random_state (estimator )
844854 # fit
845- classifier .fit (X , y )
846- y_pred = classifier .predict (X )
855+ estimator .fit (X , y )
856+ y_pred = estimator .predict (X )
847857
848- set_random_state (classifier )
858+ set_random_state (estimator )
849859 # Check that when a 2D y is given, a DataConversionWarning is
850860 # raised
851861 with warnings .catch_warnings (record = True ) as w :
852862 warnings .simplefilter ("always" , DataConversionWarning )
853863 warnings .simplefilter ("ignore" , RuntimeWarning )
854- classifier .fit (X , y [:, np .newaxis ])
864+ estimator .fit (X , y [:, np .newaxis ])
865+ y_pred_2d = estimator .predict (X )
855866 msg = "expected 1 DataConversionWarning, got: %s" % (
856867 ", " .join ([str (w_x ) for w_x in w ]))
857- assert_equal (len (w ), 1 , msg )
858- assert_array_equal (y_pred , classifier .predict (X ))
868+ if name not in MULTI_OUTPUT :
869+ # check that we warned if we don't support multi-output
870+ assert_equal (len (w ), 1 , msg )
871+ assert_array_almost_equal (y_pred .ravel (), y_pred_2d .ravel ())
859872
860873
861874def check_classifiers_classes (name , Classifier ):
0 commit comments