Skip to content

Commit f9104f6

Browse files
committed
warm_start semantics now fit exactly n_estimators in total, rather than len(self.estimators_) + self.n_estimators
1 parent 6c033bd commit f9104f6

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,9 @@ def _clear_state(self):
681681
def _resize_state(self):
682682
"""Add additional ``n_estimators`` entries to all attributes. """
683683
# self.n_estimators is the number of additional est to fit
684-
total_n_estimators = self.n_estimators + self.estimators_.shape[0]
684+
total_n_estimators = self.n_estimators
685+
if total_n_estimators < self.estimators_.shape[0]:
686+
raise ValueError('resize with smaller n_estimators than len(estimators_)')
685687

686688
self.estimators_.resize((total_n_estimators, self.loss_.K))
687689
self.train_score_.resize(total_n_estimators)
@@ -755,6 +757,10 @@ def fit(self, X, y, monitor=None):
755757
begin_at_stage = 0
756758
else:
757759
# add more estimators to fitted model
760+
# invariant: warm_start = True
761+
if self.n_estimators < self.estimators_.shape[0]:
762+
raise ValueError('n_estimators must be larger or equal to estimators_.shape[0]' +
763+
'when warm_start==True')
758764
begin_at_stage = self.estimators_.shape[0]
759765
y_pred = self.decision_function(X)
760766
self._resize_state()
@@ -801,7 +807,7 @@ def _fit_stages(self, X, y, y_pred, random_state, begin_at_stage=0,
801807
verbose_reporter.init(self, begin_at_stage)
802808

803809
# perform boosting iterations
804-
for i in range(begin_at_stage, begin_at_stage + self.n_estimators):
810+
for i in range(begin_at_stage, self.n_estimators):
805811

806812
# subsampling
807813
if do_oob:

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,23 @@
55
import numpy as np
66
import warnings
77

8-
from sklearn.utils.testing import assert_equal
9-
from sklearn.utils.testing import assert_array_equal
8+
from sklearn import datasets
9+
from sklearn.base import clone
10+
from sklearn.ensemble import GradientBoostingClassifier
11+
from sklearn.ensemble import GradientBoostingRegressor
12+
from sklearn.ensemble.gradient_boosting import ZeroEstimator
13+
from sklearn.metrics import mean_squared_error
14+
from sklearn.utils import check_random_state, tosequence
15+
from sklearn.utils.testing import assert_almost_equal
1016
from sklearn.utils.testing import assert_array_almost_equal
17+
from sklearn.utils.testing import assert_array_equal
18+
from sklearn.utils.testing import assert_equal
19+
from sklearn.utils.testing import assert_greater
1120
from sklearn.utils.testing import assert_raises
1221
from sklearn.utils.testing import assert_true
13-
from sklearn.utils.testing import assert_almost_equal
14-
from sklearn.utils.testing import assert_greater
1522
from sklearn.utils.testing import assert_warns
16-
17-
18-
from sklearn.metrics import mean_squared_error
19-
from sklearn.utils import check_random_state, tosequence
2023
from sklearn.utils.validation import DataConversionWarning
2124

22-
from sklearn.ensemble import GradientBoostingClassifier
23-
from sklearn.ensemble import GradientBoostingRegressor
24-
from sklearn.ensemble.gradient_boosting import ZeroEstimator
25-
26-
from sklearn import datasets
2725

2826
# toy sample
2927
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
@@ -637,6 +635,7 @@ def test_warm_start():
637635

638636
est_ws = cls(n_estimators=100, max_depth=1, warm_start=True)
639637
est_ws.fit(X, y)
638+
est_ws.set_params(n_estimators=200)
640639
est_ws.fit(X, y)
641640

642641
assert_array_almost_equal(est_ws.predict(X), est.predict(X))
@@ -651,7 +650,7 @@ def test_warm_start_n_estimators():
651650

652651
est_ws = cls(n_estimators=100, max_depth=1, warm_start=True)
653652
est_ws.fit(X, y)
654-
est_ws.set_params(n_estimators=200)
653+
est_ws.set_params(n_estimators=300)
655654
est_ws.fit(X, y)
656655

657656
assert_array_almost_equal(est_ws.predict(X), est.predict(X))
@@ -663,12 +662,13 @@ def test_warm_start_max_depth():
663662
for cls in [GradientBoostingRegressor, GradientBoostingClassifier]:
664663
est = cls(n_estimators=100, max_depth=1, warm_start=True)
665664
est.fit(X, y)
666-
est.set_params(n_estimators=10, max_depth=2)
665+
est.set_params(n_estimators=110, max_depth=2)
667666
est.fit(X, y)
668667

669668
# last 10 trees have different depth
670669
assert est.estimators_[0, 0].max_depth == 1
671-
assert est.estimators_[-1, 0].max_depth == 2
670+
for i in range(1, 11):
671+
assert est.estimators_[-i, 0].max_depth == 2
672672

673673

674674
def test_warm_start_clear():
@@ -696,16 +696,43 @@ def test_warm_start_zero_n_estimators():
696696
assert_raises(ValueError, est.fit, X, y)
697697

698698

699+
def test_warm_start_smaller_n_estimators():
700+
"""Test if warm start with smaller n_estimators raises error """
701+
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
702+
for cls in [GradientBoostingRegressor, GradientBoostingClassifier]:
703+
est = cls(n_estimators=100, max_depth=1, warm_start=True)
704+
est.fit(X, y)
705+
est.set_params(n_estimators=99)
706+
assert_raises(ValueError, est.fit, X, y)
707+
708+
709+
def test_warm_start_equal_n_estimators():
710+
"""Test if warm start with equal n_estimators does nothing """
711+
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
712+
for cls in [GradientBoostingRegressor, GradientBoostingClassifier]:
713+
est = cls(n_estimators=100, max_depth=1)
714+
est.fit(X, y)
715+
716+
est2 = clone(est)
717+
est2.set_params(n_estimators=est.n_estimators, warm_start=True)
718+
est2.fit(X, y)
719+
720+
assert_array_almost_equal(est2.predict(X), est.predict(X))
721+
722+
699723
def test_warm_start_oob_switch():
700724
"""Test if oob can be turned on during warm start. """
701725
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
702726
for cls in [GradientBoostingRegressor, GradientBoostingClassifier]:
703727
est = cls(n_estimators=100, max_depth=1, warm_start=True)
704728
est.fit(X, y)
705-
est.set_params(n_estimators=10, subsample=0.5)
729+
est.set_params(n_estimators=110, subsample=0.5)
706730
est.fit(X, y)
707731

708-
assert_array_equal(est.oob_improvement_[:10], np.zeros(10))
732+
assert_array_equal(est.oob_improvement_[:100], np.zeros(100))
733+
# the last 10 are not zeros
734+
assert_array_equal(est.oob_improvement_[-10:] == 0.0,
735+
np.zeros(10, dtype=np.bool))
709736

710737

711738
def test_warm_start_oob():
@@ -719,6 +746,7 @@ def test_warm_start_oob():
719746
est_ws = cls(n_estimators=100, max_depth=1, subsample=0.5,
720747
random_state=1, warm_start=True)
721748
est_ws.fit(X, y)
749+
est_ws.set_params(n_estimators=200)
722750
est_ws.fit(X, y)
723751

724752
assert_array_almost_equal(est_ws.oob_improvement_[:100],

0 commit comments

Comments (0)