Skip to content

Commit 4e6346e

Browse files
committed
FIX raise error properly when n_features differ in fit and apply
1 parent 932726f commit 4e6346e

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

sklearn/ensemble/forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def apply(self, X):
160160
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
161161
results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
162162
backend="threading")(
163-
delayed(_parallel_helper)(tree.tree_, 'apply', X)
163+
delayed(_parallel_helper)(tree, 'apply', X, check_input=False)
164164
for tree in self.estimators_)
165165

166166
return np.array(results).T

sklearn/tree/tests/test_tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,12 @@ def test_error():
541541
est = TreeEstimator()
542542
est.fit(np.dot(X, Xt), y)
543543
assert_raises(ValueError, est.predict, X)
544+
assert_raises(ValueError, est.apply, X)
544545

545546
clf = TreeEstimator()
546547
clf.fit(X, y)
547548
assert_raises(ValueError, clf.predict, Xt)
549+
assert_raises(ValueError, clf.apply, Xt)
548550

549551
# apply before fitting
550552
est = TreeEstimator()

sklearn/tree/tree.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def predict(self, X, check_input=True):
375375
else:
376376
return proba[:, :, 0]
377377

378-
def apply(self, X):
378+
def apply(self, X, check_input=True):
379379
"""
380380
Returns the index of the leaf that each sample is predicted as.
381381
@@ -386,6 +386,10 @@ def apply(self, X):
386386
``dtype=np.float32`` and if a sparse matrix is provided
387387
to a sparse ``csr_matrix``.
388388
389+
check_input : boolean, (default=True)
390+
Allow to bypass several input checking.
391+
Don't use this parameter unless you know what you do.
392+
389393
Returns
390394
-------
391395
X_leaves : array_like, shape = [n_samples,]
@@ -398,7 +402,15 @@ def apply(self, X):
398402
raise NotFittedError("Estimator not fitted, "
399403
"call `fit` before `apply`.")
400404

401-
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
405+
if check_input:
406+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
407+
408+
n_features = X.shape[1]
409+
if self.n_features_ != n_features:
410+
raise ValueError("Number of features of the model must "
411+
" match the input. Model n_features is %s and "
412+
" input n_features is %s "
413+
% (self.n_features_, n_features))
402414

403415
return self.tree_.apply(X)
404416

0 commit comments

Comments
 (0)