Skip to content

Commit 2ba21ec

Browse files
FedericoV authored and amueller committed
Added test_classifiers_pickle to tests.
1 parent 60374ee commit 2ba21ec

File tree

1 file changed

+50
-3
lines changed

1 file changed

+50
-3
lines changed

sklearn/tests/test_common.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,53 @@ def test_classifiers_classes():
573573
(clf, classes, clf.classes_))
574574

575575

576+
def test_classifiers_pickle():
    # Check that every classifier survives a pickle round trip: after being
    # fitted, pickled and unpickled, it must predict the same values on the
    # training data as the original fitted instance.  Also checks that fit
    # raises ValueError when X and y have mismatching lengths.
    # Runs once on a multi-class problem and once on a binary one.
    classifiers = all_estimators(type_filter='classifier')
    X_m, y_m = make_blobs(random_state=0)
    X_m, y_m = shuffle(X_m, y_m, random_state=7)
    X_m = StandardScaler().fit_transform(X_m)
    # generate binary problem from multi-class one by dropping label 2
    y_b = y_m[y_m != 2]
    X_b = X_m[y_m != 2]
    for (X, y) in [(X_m, y_m), (X_b, y_b)]:
        # do it once with multiclass, once with binary
        for name, Clf in classifiers:
            if Clf in dont_test:
                continue
            if Clf in [MultinomialNB, BernoulliNB]:
                # TODO also test these!
                continue
            # catch deprecation warnings
            with warnings.catch_warnings(record=True):
                clf = Clf()
            # raises error on malformed input for fit
            assert_raises(ValueError, clf.fit, X, y[:-1])

            # fit and store the predictions of the original estimator
            clf.fit(X, y)
            y_pred = clf.predict(X)

            # Round-trip through an in-memory pickle.  dumps/loads avoids
            # rewinding a StringIO buffer through its private ``pos``
            # attribute, which only exists on the pure-Python StringIO.
            unpickled_clf = pickle.loads(pickle.dumps(clf))
            pickled_y_pred = unpickled_clf.predict(X)

            try:
                assert_array_almost_equal(pickled_y_pred, y_pred)
            except Exception:
                # Name the offending estimator, then re-raise with a bare
                # ``raise`` so the original traceback is preserved.
                print ("Esimator %s doesn't predict the same value "
                       "after pickling" % name)
                raise
621+
622+
576623
def test_regressors_int():
577624
# test if regressors can cope with integer labels (by converting them to
578625
# float)
@@ -678,16 +725,16 @@ def test_regressor_pickle():
678725
else:
679726
y_ = y
680727
reg.fit(X, y_)
681-
pred = reg.predict(X)
728+
y_pred = reg.predict(X)
682729
# store old predictions
683730
pickled_reg = StringIO.StringIO()
684731
pickle.dump(reg, pickled_reg)
685732
pickled_reg.pos = 0
686733
unpickled_reg = pickle.load(pickled_reg)
687-
new_pred = unpickled_reg.predict(X)
734+
pickled_y_pred = unpickled_reg.predict(X)
688735

689736
try:
690-
assert_array_almost_equal(new_pred, pred)
737+
assert_array_almost_equal(pickled_y_pred, y_pred)
691738
except Exception, exc:
692739
succeeded = False
693740
print ("Esimator %s doesn't predict the same value "

0 commit comments

Comments
 (0)