Skip to content

Commit 0015659

Browse files
committed
Merge pull request scikit-learn#5081 from amueller/transformers_consistent_n_samples
[MRG + 2] Add common test that transformers don't change n_samples.
2 parents 39bad0a + 32f2e8e commit 0015659

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def _check_transformer(name, Transformer, X, y):
387387
for x_pred in X_pred:
388388
assert_equal(x_pred.shape[0], n_samples)
389389
else:
390+
# check for consistent n_samples
390391
assert_equal(X_pred.shape[0], n_samples)
391392

392393
if hasattr(transformer, 'transform'):
@@ -415,6 +416,8 @@ def _check_transformer(name, Transformer, X, y):
415416
X_pred, X_pred3, 2,
416417
"consecutive fit_transform outcomes not consistent in %s"
417418
% Transformer)
419+
assert_equal(len(X_pred2), n_samples)
420+
assert_equal(len(X_pred3), n_samples)
418421

419422
# raises error on malformed input for transform
420423
if hasattr(X, 'T'):

0 commit comments

Comments
 (0)