Skip to content

Commit 94318a8

Browse files
committed
ENH in transformer pickle test, only test transformers that provide a 'transform' method. and only test that.
1 parent 1d958b2 commit 94318a8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/tests/test_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ def test_transformers_pickle():
396396
# catch deprecation warnings
397397
with warnings.catch_warnings(record=True):
398398
trans = Trans()
399+
if not hasattr(trans, 'transform'):
400+
continue
399401
set_random_state(trans)
400402
if hasattr(trans, 'compute_importances'):
401403
trans.compute_importances = True
@@ -413,7 +415,6 @@ def test_transformers_pickle():
413415
trans.n_components = 1
414416

415417
# fit
416-
417418
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
418419
random_state = np.random.RandomState(seed=12345)
419420
y_ = np.vstack([y, 2 * y + random_state.randint(2, size=len(y))])
@@ -422,10 +423,10 @@ def test_transformers_pickle():
422423
y_ = y
423424

424425
trans.fit(X, y_)
425-
X_pred = trans.fit_transform(X, y=y_)
426+
X_pred = trans.fit(X, y_).transform(X)
426427
pickled_trans = pickle.dumps(trans)
427428
unpickled_trans = pickle.loads(pickled_trans)
428-
pickled_X_pred = unpickled_trans.fit_transform(X, y=y_)
429+
pickled_X_pred = unpickled_trans.transform(X)
429430

430431
try:
431432
assert_array_almost_equal(pickled_X_pred, X_pred)

0 commit comments

Comments
 (0)