55import numpy as np
66import warnings
77
8- from sklearn .utils .testing import assert_equal
9- from sklearn .utils .testing import assert_array_equal
8+ from sklearn import datasets
9+ from sklearn .base import clone
10+ from sklearn .ensemble import GradientBoostingClassifier
11+ from sklearn .ensemble import GradientBoostingRegressor
12+ from sklearn .ensemble .gradient_boosting import ZeroEstimator
13+ from sklearn .metrics import mean_squared_error
14+ from sklearn .utils import check_random_state , tosequence
15+ from sklearn .utils .testing import assert_almost_equal
1016from sklearn .utils .testing import assert_array_almost_equal
17+ from sklearn .utils .testing import assert_array_equal
18+ from sklearn .utils .testing import assert_equal
19+ from sklearn .utils .testing import assert_greater
1120from sklearn .utils .testing import assert_raises
1221from sklearn .utils .testing import assert_true
13- from sklearn .utils .testing import assert_almost_equal
14- from sklearn .utils .testing import assert_greater
1522from sklearn .utils .testing import assert_warns
16-
17-
18- from sklearn .metrics import mean_squared_error
19- from sklearn .utils import check_random_state , tosequence
2023from sklearn .utils .validation import DataConversionWarning
2124
22- from sklearn .ensemble import GradientBoostingClassifier
23- from sklearn .ensemble import GradientBoostingRegressor
24- from sklearn .ensemble .gradient_boosting import ZeroEstimator
25-
26- from sklearn import datasets
2725
2826# toy sample
2927X = [[- 2 , - 1 ], [- 1 , - 1 ], [- 1 , - 2 ], [1 , 1 ], [1 , 2 ], [2 , 1 ]]
@@ -637,6 +635,7 @@ def test_warm_start():
637635
638636 est_ws = cls (n_estimators = 100 , max_depth = 1 , warm_start = True )
639637 est_ws .fit (X , y )
638+ est_ws .set_params (n_estimators = 200 )
640639 est_ws .fit (X , y )
641640
642641 assert_array_almost_equal (est_ws .predict (X ), est .predict (X ))
@@ -651,7 +650,7 @@ def test_warm_start_n_estimators():
651650
652651 est_ws = cls (n_estimators = 100 , max_depth = 1 , warm_start = True )
653652 est_ws .fit (X , y )
654- est_ws .set_params (n_estimators = 200 )
653+ est_ws .set_params (n_estimators = 300 )
655654 est_ws .fit (X , y )
656655
657656 assert_array_almost_equal (est_ws .predict (X ), est .predict (X ))
@@ -663,12 +662,13 @@ def test_warm_start_max_depth():
663662 for cls in [GradientBoostingRegressor , GradientBoostingClassifier ]:
664663 est = cls (n_estimators = 100 , max_depth = 1 , warm_start = True )
665664 est .fit (X , y )
666- est .set_params (n_estimators = 10 , max_depth = 2 )
665+ est .set_params (n_estimators = 110 , max_depth = 2 )
667666 est .fit (X , y )
668667
669668 # last 10 trees have different depth
670669 assert est .estimators_ [0 , 0 ].max_depth == 1
671- assert est .estimators_ [- 1 , 0 ].max_depth == 2
670+ for i in range (1 , 11 ):
671+ assert est .estimators_ [- i , 0 ].max_depth == 2
672672
673673
674674def test_warm_start_clear ():
@@ -696,16 +696,43 @@ def test_warm_start_zero_n_estimators():
696696 assert_raises (ValueError , est .fit , X , y )
697697
698698
def test_warm_start_smaller_n_estimators():
    """A warm-started fit must reject a reduced ``n_estimators``."""
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
    for Estimator in (GradientBoostingRegressor, GradientBoostingClassifier):
        model = Estimator(n_estimators=100, max_depth=1, warm_start=True)
        model.fit(X, y)
        # Shrinking the ensemble while warm-starting is unsupported and
        # must raise a ValueError on the subsequent fit.
        model.set_params(n_estimators=99)
        assert_raises(ValueError, model.fit, X, y)
708+
def test_warm_start_equal_n_estimators():
    """Warm-starting with an unchanged ``n_estimators`` is a no-op."""
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
    for Estimator in (GradientBoostingRegressor, GradientBoostingClassifier):
        reference = Estimator(n_estimators=100, max_depth=1)
        reference.fit(X, y)

        # Refitting a clone with warm_start on and the same ensemble size
        # must leave the predictions unchanged.
        warm = clone(reference)
        warm.set_params(n_estimators=reference.n_estimators,
                        warm_start=True)
        warm.fit(X, y)

        assert_array_almost_equal(warm.predict(X), reference.predict(X))
722+
def test_warm_start_oob_switch():
    """Test if oob can be turned on during warm start.

    The first 100 stages were fit without subsampling, so no OOB
    improvement is recorded for them; the 10 stages added after enabling
    ``subsample=0.5`` must have (non-zero) OOB improvement values.
    """
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
    for cls in [GradientBoostingRegressor, GradientBoostingClassifier]:
        est = cls(n_estimators=100, max_depth=1, warm_start=True)
        est.fit(X, y)
        est.set_params(n_estimators=110, subsample=0.5)
        est.fit(X, y)

        # stages fit without subsampling have zero oob improvement
        assert_array_equal(est.oob_improvement_[:100], np.zeros(100))
        # the last 10 are not zeros
        # NOTE: use builtin `bool` — the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, so `dtype=np.bool` raises
        # AttributeError on modern NumPy.
        assert_array_equal(est.oob_improvement_[-10:] == 0.0,
                           np.zeros(10, dtype=bool))
709736
710737
711738def test_warm_start_oob ():
@@ -719,6 +746,7 @@ def test_warm_start_oob():
719746 est_ws = cls (n_estimators = 100 , max_depth = 1 , subsample = 0.5 ,
720747 random_state = 1 , warm_start = True )
721748 est_ws .fit (X , y )
749+ est_ws .set_params (n_estimators = 200 )
722750 est_ws .fit (X , y )
723751
724752 assert_array_almost_equal (est_ws .oob_improvement_ [:100 ],
0 commit comments