@@ -67,19 +67,11 @@ class calls the ``fit`` method of each sub-estimator on random samples
6767MAX_INT = np .iinfo (np .int32 ).max
6868
6969
70- def _parallel_build_trees (n_trees , forest , X , y ,
71- sample_weight , seeds , verbose ):
70+ def _parallel_build_trees (trees , forest , X , y , sample_weight , verbose ):
7271 """Private function used to build a batch of trees within a job."""
73- trees = []
74-
75- for i in range (n_trees ):
76- random_state = check_random_state (seeds [i ])
72+ for i , tree in enumerate (trees ):
7773 if verbose > 1 :
78- print ("building tree %d of %d" % (i + 1 , n_trees ))
79- seed = random_state .randint (MAX_INT )
80-
81- tree = forest ._make_estimator (append = False )
82- tree .set_params (random_state = seed )
74+ print ("building tree %d of %d" % (i + 1 , len (trees )))
8375
8476 if forest .bootstrap :
8577 n_samples = X .shape [0 ]
@@ -88,6 +80,7 @@ def _parallel_build_trees(n_trees, forest, X, y,
8880 else :
8981 curr_sample_weight = sample_weight .copy ()
9082
83+ random_state = check_random_state (tree .random_state )
9184 indices = random_state .randint (0 , n_samples , n_samples )
9285 sample_counts = bincount (indices , minlength = n_samples )
9386 curr_sample_weight *= sample_counts
@@ -103,8 +96,6 @@ def _parallel_build_trees(n_trees, forest, X, y,
10396 sample_weight = sample_weight ,
10497 check_input = False )
10598
106- trees .append (tree )
107-
10899 return trees
109100
110101
@@ -264,10 +255,13 @@ def fit(self, X, y, sample_weight=None):
264255 " if bootstrap=True" )
265256
266257 # Assign chunk of trees to jobs
267- n_jobs , n_trees , _ = _partition_estimators (self )
258+ n_jobs , n_trees , starts = _partition_estimators (self )
259+ trees = []
268260
269- # Precalculate the random states
270- seeds = [random_state .randint (MAX_INT , size = i ) for i in n_trees ]
261+ for i in range (self .n_estimators ):
262+ tree = self ._make_estimator (append = False )
263+ tree .set_params (random_state = random_state .randint (MAX_INT ))
264+ trees .append (tree )
271265
272266 # Free allocated memory, if any
273267 self .estimators_ = None
@@ -278,12 +272,11 @@ def fit(self, X, y, sample_weight=None):
278272 all_trees = Parallel (n_jobs = n_jobs , verbose = self .verbose ,
279273 backend = "threading" )(
280274 delayed (_parallel_build_trees )(
281- n_trees [ i ],
275+ trees [ starts [ i ]: starts [ i + 1 ] ],
282276 self ,
283277 X ,
284278 y ,
285279 sample_weight ,
286- seeds [i ],
287280 verbose = self .verbose )
288281 for i in range (n_jobs ))
289282
0 commit comments