Skip to content

Commit ab92433

Browse files
committed
Merge pull request scikit-learn#1688 from hrishikeshio/fit_transform
Test that fit_transform(X) does the same as fit(X).transform(X); fixes scikit-learn#1687
2 parents a7c6b06 + 3cc8158 commit ab92433

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sklearn/tests/test_common.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,17 +218,25 @@ def test_transformers():
218218
if hasattr(trans, 'transform'):
219219
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
220220
X_pred2 = trans.transform(X, y_)
221+
X_pred3 = trans.fit_transform(X, y=y_)
221222
else:
222223
X_pred2 = trans.transform(X)
224+
X_pred3 = trans.fit_transform(X, y=y_)
223225
if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
224-
for x_pred, x_pred2 in zip(X_pred, X_pred2):
226+
for x_pred, x_pred2, x_pred3 in zip(X_pred, X_pred2, X_pred3):
225227
assert_array_almost_equal(
226228
x_pred, x_pred2, 2,
227229
"fit_transform not correct in %s" % Trans)
230+
assert_array_almost_equal(
231+
x_pred3, x_pred2, 2,
232+
"fit_transform not correct in %s" % Trans)
228233
else:
229234
assert_array_almost_equal(
230235
X_pred, X_pred2, 2,
231236
"fit_transform not correct in %s" % Trans)
237+
assert_array_almost_equal(
238+
X_pred3, X_pred2, 2,
239+
"fit_transform not correct in %s" % Trans)
232240

233241
# raises error on malformed input for transform
234242
assert_raises(ValueError, trans.transform, X.T)
@@ -532,7 +540,7 @@ def test_classifiers_classes():
532540
y = 2 * y + 1
533541
classes = np.unique(y)
534542
# TODO: make work with next line :)
535-
#y = y.astype(np.str)
543+
# y = y.astype(np.str)
536544
for name, Clf in classifiers:
537545
if name in dont_test:
538546
continue
@@ -647,7 +655,7 @@ def test_configure():
647655
with warnings.catch_warnings():
648656
# The configuration spits out warnings when not finding
649657
# Blas/Atlas development headers
650-
warnings.simplefilter('ignore', UserWarning)
658+
warnings.simplefilter('ignore', UserWarning)
651659
execfile('setup.py', dict(__name__='__main__'))
652660
finally:
653661
sys.argv = old_argv

0 commit comments

Comments
 (0)