Skip to content

Commit 86ab060

Browse files
committed
Added test for fit_transform(X)==fit(X).transform(X)
1 parent c481ff8 commit 86ab060

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

sklearn/tests/test_common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,21 +214,29 @@ def test_transformers():
214214
print()
215215
succeeded = False
216216
continue
217-
217+
218218
if hasattr(trans, 'transform'):
219219
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
220220
X_pred2 = trans.transform(X, y_)
221+
X_pred3 = trans.fit_transform(X, y=y_)
221222
else:
222223
X_pred2 = trans.transform(X)
224+
X_pred3 = trans.fit_transform(X,y=y_)
223225
if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
224-
for x_pred, x_pred2 in zip(X_pred, X_pred2):
226+
for x_pred, x_pred2, x_pred3 in zip(X_pred, X_pred2, X_pred3):
225227
assert_array_almost_equal(
226228
x_pred, x_pred2, 2,
227229
"fit_transform not correct in %s" % Trans)
230+
assert_array_almost_equal(
231+
x_pred3, x_pred2, 2,
232+
"fit_transform not correct in %s" % Trans)
228233
else:
229234
assert_array_almost_equal(
230235
X_pred, X_pred2, 2,
231236
"fit_transform not correct in %s" % Trans)
237+
assert_array_almost_equal(
238+
X_pred3, X_pred2, 2,
239+
"fit_transform not correct in %s" % Trans)
232240

233241
# raises error on malformed input for transform
234242
assert_raises(ValueError, trans.transform, X.T)

0 commit comments

Comments
 (0)