@@ -573,6 +573,53 @@ def test_classifiers_classes():
573573 (clf , classes , clf .classes_ ))
574574
575575
576+ def test_classifiers_pickle ():
577+ # test if classifiers do something sensible on training set
578+ # also test all shapes / shape errors
579+ classifiers = all_estimators (type_filter = 'classifier' )
580+ X_m , y_m = make_blobs (random_state = 0 )
581+ X_m , y_m = shuffle (X_m , y_m , random_state = 7 )
582+ X_m = StandardScaler ().fit_transform (X_m )
583+ # generate binary problem from multi-class one
584+ y_b = y_m [y_m != 2 ]
585+ X_b = X_m [y_m != 2 ]
586+ succeeded = True
587+ for (X , y ) in [(X_m , y_m ), (X_b , y_b )]:
588+ # do it once with binary, once with multiclass
589+ classes = np .unique (y )
590+ n_classes = len (classes )
591+ n_samples , n_features = X .shape
592+ for name , Clf in classifiers :
593+ if Clf in dont_test :
594+ continue
595+ if Clf in [MultinomialNB , BernoulliNB ]:
596+ # TODO also test these!
597+ continue
598+ # catch deprecation warnings
599+ with warnings .catch_warnings (record = True ):
600+ clf = Clf ()
601+ # raises error on malformed input for fit
602+ assert_raises (ValueError , clf .fit , X , y [:- 1 ])
603+
604+ # fit
605+ clf .fit (X , y )
606+ y_pred = clf .predict (X )
607+ pickled_clf = StringIO .StringIO ()
608+ pickle .dump (clf , pickled_clf )
609+ pickled_clf .pos = 0
610+ unpickled_clf = pickle .load (pickled_clf )
611+ pickled_y_pred = unpickled_clf .predict (X )
612+
613+ try :
614+ assert_array_almost_equal (pickled_y_pred , y_pred )
615+ except Exception , exc :
616+ succeeded = False
617+ print ("Esimator %s doesn't predict the same value "
618+ "after pickling" % name )
619+ raise exc
620+ assert_true (succeeded )
621+
622+
576623def test_regressors_int ():
577624 # test if regressors can cope with integer labels (by converting them to
578625 # float)
@@ -678,16 +725,16 @@ def test_regressor_pickle():
678725 else :
679726 y_ = y
680727 reg .fit (X , y_ )
681- pred = reg .predict (X )
728+ y_pred = reg .predict (X )
682729 # store old predictions
683730 pickled_reg = StringIO .StringIO ()
684731 pickle .dump (reg , pickled_reg )
685732 pickled_reg .pos = 0
686733 unpickled_reg = pickle .load (pickled_reg )
687- new_pred = unpickled_reg .predict (X )
734+ pickled_y_pred = unpickled_reg .predict (X )
688735
689736 try :
690- assert_array_almost_equal (new_pred , pred )
737+ assert_array_almost_equal (pickled_y_pred , y_pred )
691738 except Exception , exc :
692739 succeeded = False
693740 print ("Esimator %s doesn't predict the same value "
0 commit comments