
Commit 132ad99

jcusick13 authored and jnothman committed
ENH Pass original dataset to Stacking final estimator (scikit-learn#15138)
1 parent 5799643 commit 132ad99

File tree: 3 files changed, +107 -16 lines
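
A minimal usage sketch of the option this commit adds (assuming a scikit-learn build that includes it, i.e. 0.22 or later). The original features end up appended after the prediction columns, which is what the updated tests assert via X_trans[:, -4:]:

from sklearn.datasets import load_iris
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC

X, y = load_iris(return_X_y=True)
X = scale(X)  # mirrors the tests; avoids convergence warnings

clf = StackingClassifier(
    estimators=[('lr', LogisticRegression()), ('svc', LinearSVC())],
    passthrough=True,  # new in this commit; default False keeps the old behavior
)
clf.fit(X, y)
# 3 predict_proba columns + 3 decision_function columns + 4 original features
print(clf.transform(X).shape)  # (150, 10)

Keeping the default at passthrough=False preserves the previous behavior, where the final estimator sees only the cross-validated predictions.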

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 0 deletions
@@ -285,6 +285,11 @@ Changelog
   by the max of the samples with non-null weights only.
   :pr:`14294` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- |Enhancement| Adds ``passthrough`` to :class:`ensemble.StackingClassifier`
+  and :class:`ensemble.StackingRegressor`, allowing the original dataset
+  to be used when training the final estimator.
+  :pr:`15138` by :user:`Jon Cusick <jcusick13>`.
+
 :mod:`sklearn.feature_extraction`
 .................................
 

sklearn/ensemble/_stacking.py

Lines changed: 37 additions & 8 deletions
@@ -8,6 +8,7 @@
 
 import numpy as np
 from joblib import Parallel, delayed
+import scipy.sparse as sparse
 
 from ..base import clone
 from ..base import ClassifierMixin, RegressorMixin, TransformerMixin
@@ -37,22 +38,30 @@ class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble,
 
     @abstractmethod
     def __init__(self, estimators, final_estimator=None, cv=None,
-                 stack_method='auto', n_jobs=None, verbose=0):
+                 stack_method='auto', n_jobs=None, verbose=0,
+                 passthrough=False):
         super().__init__(estimators=estimators)
         self.final_estimator = final_estimator
         self.cv = cv
         self.stack_method = stack_method
         self.n_jobs = n_jobs
         self.verbose = verbose
+        self.passthrough = passthrough
 
     def _clone_final_estimator(self, default):
         if self.final_estimator is not None:
             self.final_estimator_ = clone(self.final_estimator)
         else:
             self.final_estimator_ = clone(default)
 
-    def _concatenate_predictions(self, predictions):
-        """Concatenate the predictions of each first layer learner.
+    def _concatenate_predictions(self, X, predictions):
+        """Concatenate the predictions of each first layer learner and
+        possibly the input dataset `X`.
+
+        If `X` is sparse and `self.passthrough` is False, the output of
+        `transform` will be dense (the predictions). If `X` is sparse
+        and `self.passthrough` is True, the output of `transform` will
+        be sparse.
 
         This helper is in charge of ensuring the predictions are 2D arrays and
         it will drop one of the probability column when using probabilities
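
The docstring addition above pins down a contract worth demonstrating: with a sparse X, transform stays dense unless passthrough is set. A quick end-to-end sketch of that rule (an illustration, not part of the commit; it presumes a build containing this change and uses LinearRegression/LinearSVR, which accept sparse input):

import scipy.sparse as sparse
from sklearn.datasets import load_diabetes
from sklearn.ensemble import StackingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import scale
from sklearn.svm import LinearSVR

X, y = load_diabetes(return_X_y=True)
X_sp = sparse.csr_matrix(scale(X))
estimators = [('lr', LinearRegression()), ('svr', LinearSVR())]

# passthrough=False: only the dense prediction columns come back
dense_out = StackingRegressor(estimators=estimators).fit(X_sp, y).transform(X_sp)
assert not sparse.issparse(dense_out)

# passthrough=True: X is appended, so sparsity (and format) is preserved
sparse_out = StackingRegressor(
    estimators=estimators, passthrough=True
).fit(X_sp, y).transform(X_sp)
assert sparse.issparse(sparse_out) and sparse_out.format == 'csr'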
@@ -72,7 +81,12 @@ def _concatenate_predictions(self, predictions):
                 X_meta.append(preds[:, 1:])
             else:
                 X_meta.append(preds)
-        return np.concatenate(X_meta, axis=1)
+        if self.passthrough:
+            X_meta.append(X)
+            if sparse.issparse(X):
+                return sparse.hstack(X_meta, format=X.format)
+
+        return np.hstack(X_meta)
 
     @staticmethod
     def _method_name(name, estimator, method):
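
The format-preserving branch above relies on a scipy behavior worth spelling out: scipy.sparse.hstack accepts a mix of dense arrays and sparse matrices, and the format=X.format argument coerces the result back to the input's format (without it, hstack typically returns COO). A standalone sketch with toy data, independent of the commit:

import numpy as np
import scipy.sparse as sparse

preds = np.array([[0.2], [0.9], [0.4]])              # dense meta-features
X = sparse.random(3, 5, density=0.3, format='csr')   # sparse original data

# Dense blocks are coerced to sparse; format='csr' mirrors format=X.format
X_meta = sparse.hstack([preds, X], format=X.format)
assert sparse.issparse(X_meta) and X_meta.format == 'csr'
assert X_meta.shape == (3, 6)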
@@ -165,7 +179,7 @@ def fit(self, X, y, sample_weight=None):
             if est != 'drop'
         ]
 
-        X_meta = self._concatenate_predictions(predictions)
+        X_meta = self._concatenate_predictions(X, predictions)
         if sample_weight is not None:
             try:
                 self.final_estimator_.fit(
@@ -192,7 +206,7 @@ def _transform(self, X):
             for est, meth in zip(self.estimators_, self.stack_method_)
             if est != 'drop'
         ]
-        return self._concatenate_predictions(predictions)
+        return self._concatenate_predictions(X, predictions)
 
     @if_delegate_has_method(delegate='final_estimator_')
     def predict(self, X, **predict_params):
@@ -288,6 +302,12 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
         `None` means 1 unless in a `joblib.parallel_backend` context. -1 means
         using all processors. See Glossary for more details.
 
+    passthrough : bool, default=False
+        When False, only the predictions of estimators will be used as
+        training data for `final_estimator`. When True, the
+        `final_estimator` is trained on the predictions as well as the
+        original training data.
+
     Attributes
     ----------
     estimators_ : list of estimators
@@ -344,13 +364,15 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
 
     """
     def __init__(self, estimators, final_estimator=None, cv=None,
-                 stack_method='auto', n_jobs=None, verbose=0):
+                 stack_method='auto', n_jobs=None, passthrough=False,
+                 verbose=0):
         super().__init__(
             estimators=estimators,
             final_estimator=final_estimator,
             cv=cv,
             stack_method=stack_method,
             n_jobs=n_jobs,
+            passthrough=passthrough,
             verbose=verbose
         )
 
@@ -525,6 +547,12 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
         `None` means 1 unless in a `joblib.parallel_backend` context. -1 means
         using all processors. See Glossary for more details.
 
+    passthrough : bool, default=False
+        When False, only the predictions of estimators will be used as
+        training data for `final_estimator`. When True, the
+        `final_estimator` is trained on the predictions as well as the
+        original training data.
+
     Attributes
     ----------
     estimators_ : list of estimator
@@ -569,13 +597,14 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
 
     """
     def __init__(self, estimators, final_estimator=None, cv=None, n_jobs=None,
-                 verbose=0):
+                 passthrough=False, verbose=0):
         super().__init__(
             estimators=estimators,
             final_estimator=final_estimator,
             cv=cv,
             stack_method="predict",
             n_jobs=n_jobs,
+            passthrough=passthrough,
             verbose=verbose
         )
 
sklearn/ensemble/tests/test_stacking.py

Lines changed: 65 additions & 8 deletions
@@ -5,6 +5,7 @@
 
 import pytest
 import numpy as np
+import scipy.sparse as sparse
 
 from sklearn.base import BaseEstimator
 from sklearn.base import ClassifierMixin
@@ -38,6 +39,7 @@
 from sklearn.model_selection import KFold
 
 from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_allclose_dense_sparse
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils.estimator_checks import check_estimator
 from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
@@ -52,23 +54,28 @@
 @pytest.mark.parametrize(
     "final_estimator", [None, RandomForestClassifier(random_state=42)]
 )
-def test_stacking_classifier_iris(cv, final_estimator):
+@pytest.mark.parametrize("passthrough", [False, True])
+def test_stacking_classifier_iris(cv, final_estimator, passthrough):
     # prescale the data to avoid convergence warning without using a pipeline
     # for later assert
     X_train, X_test, y_train, y_test = train_test_split(
         scale(X_iris), y_iris, stratify=y_iris, random_state=42
     )
     estimators = [('lr', LogisticRegression()), ('svc', LinearSVC())]
     clf = StackingClassifier(
-        estimators=estimators, final_estimator=final_estimator, cv=cv
+        estimators=estimators, final_estimator=final_estimator, cv=cv,
+        passthrough=passthrough
     )
     clf.fit(X_train, y_train)
     clf.predict(X_test)
     clf.predict_proba(X_test)
     assert clf.score(X_test, y_test) > 0.8
 
     X_trans = clf.transform(X_test)
-    assert X_trans.shape[1] == 6
+    expected_column_count = 10 if passthrough else 6
+    assert X_trans.shape[1] == expected_column_count
+    if passthrough:
+        assert_allclose(X_test, X_trans[:, -4:])
 
     clf.set_params(lr='drop')
     clf.fit(X_train, y_train)
@@ -79,7 +86,10 @@ def test_stacking_classifier_iris(cv, final_estimator):
     clf.decision_function(X_test)
 
     X_trans = clf.transform(X_test)
-    assert X_trans.shape[1] == 3
+    expected_column_count_drop = 7 if passthrough else 3
+    assert X_trans.shape[1] == expected_column_count_drop
+    if passthrough:
+        assert_allclose(X_test, X_trans[:, -4:])
 
 
 def test_stacking_classifier_drop_column_binary_classification():
@@ -161,15 +171,18 @@ def test_stacking_regressor_drop_estimator():
     (RandomForestRegressor(random_state=42), {}),
     (DummyRegressor(), {'return_std': True})]
 )
-def test_stacking_regressor_diabetes(cv, final_estimator, predict_params):
+@pytest.mark.parametrize("passthrough", [False, True])
+def test_stacking_regressor_diabetes(cv, final_estimator, predict_params,
+                                     passthrough):
     # prescale the data to avoid convergence warning without using a pipeline
     # for later assert
     X_train, X_test, y_train, _ = train_test_split(
         scale(X_diabetes), y_diabetes, random_state=42
     )
     estimators = [('lr', LinearRegression()), ('svr', LinearSVR())]
     reg = StackingRegressor(
-        estimators=estimators, final_estimator=final_estimator, cv=cv
+        estimators=estimators, final_estimator=final_estimator, cv=cv,
+        passthrough=passthrough
     )
     reg.fit(X_train, y_train)
     result = reg.predict(X_test, **predict_params)
@@ -178,14 +191,58 @@ def test_stacking_regressor_diabetes(cv, final_estimator, predict_params):
     assert len(result) == expected_result_length
 
     X_trans = reg.transform(X_test)
-    assert X_trans.shape[1] == 2
+    expected_column_count = 12 if passthrough else 2
+    assert X_trans.shape[1] == expected_column_count
+    if passthrough:
+        assert_allclose(X_test, X_trans[:, -10:])
 
     reg.set_params(lr='drop')
     reg.fit(X_train, y_train)
     reg.predict(X_test)
 
     X_trans = reg.transform(X_test)
-    assert X_trans.shape[1] == 1
+    expected_column_count_drop = 11 if passthrough else 1
+    assert X_trans.shape[1] == expected_column_count_drop
+    if passthrough:
+        assert_allclose(X_test, X_trans[:, -10:])
+
+
+@pytest.mark.parametrize('fmt', ['csc', 'csr', 'coo'])
+def test_stacking_regressor_sparse_passthrough(fmt):
+    # Check passthrough behavior on a sparse X matrix
+    X_train, X_test, y_train, _ = train_test_split(
+        sparse.coo_matrix(scale(X_diabetes)).asformat(fmt),
+        y_diabetes, random_state=42
+    )
+    estimators = [('lr', LinearRegression()), ('svr', LinearSVR())]
+    rf = RandomForestRegressor(n_estimators=10, random_state=42)
+    clf = StackingRegressor(
+        estimators=estimators, final_estimator=rf, cv=5, passthrough=True
+    )
+    clf.fit(X_train, y_train)
+    X_trans = clf.transform(X_test)
+    assert_allclose_dense_sparse(X_test, X_trans[:, -10:])
+    assert sparse.issparse(X_trans)
+    assert X_test.format == X_trans.format
+
+
+@pytest.mark.parametrize('fmt', ['csc', 'csr', 'coo'])
+def test_stacking_classifier_sparse_passthrough(fmt):
+    # Check passthrough behavior on a sparse X matrix
+    X_train, X_test, y_train, _ = train_test_split(
+        sparse.coo_matrix(scale(X_iris)).asformat(fmt),
+        y_iris, random_state=42
+    )
+    estimators = [('lr', LogisticRegression()), ('svc', LinearSVC())]
+    rf = RandomForestClassifier(n_estimators=10, random_state=42)
+    clf = StackingClassifier(
+        estimators=estimators, final_estimator=rf, cv=5, passthrough=True
+    )
+    clf.fit(X_train, y_train)
+    X_trans = clf.transform(X_test)
+    assert_allclose_dense_sparse(X_test, X_trans[:, -4:])
+    assert sparse.issparse(X_trans)
+    assert X_test.format == X_trans.format
 
 
 def test_stacking_classifier_drop_binary_prob():
