Skip to content

Commit 36ab692

Browse files
committed
Pre-initialize all trees before dispatching
1 parent 904a526 commit 36ab692

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

sklearn/ensemble/forest.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,11 @@ class calls the ``fit`` method of each sub-estimator on random samples
6767
MAX_INT = np.iinfo(np.int32).max
6868

6969

70-
def _parallel_build_trees(n_trees, forest, X, y,
71-
sample_weight, seeds, verbose):
70+
def _parallel_build_trees(trees, forest, X, y, sample_weight, verbose):
7271
"""Private function used to build a batch of trees within a job."""
73-
trees = []
74-
75-
for i in range(n_trees):
76-
random_state = check_random_state(seeds[i])
72+
for i, tree in enumerate(trees):
7773
if verbose > 1:
78-
print("building tree %d of %d" % (i + 1, n_trees))
79-
seed = random_state.randint(MAX_INT)
80-
81-
tree = forest._make_estimator(append=False)
82-
tree.set_params(random_state=seed)
74+
print("building tree %d of %d" % (i + 1, len(trees)))
8375

8476
if forest.bootstrap:
8577
n_samples = X.shape[0]
@@ -88,6 +80,7 @@ def _parallel_build_trees(n_trees, forest, X, y,
8880
else:
8981
curr_sample_weight = sample_weight.copy()
9082

83+
random_state = check_random_state(tree.random_state)
9184
indices = random_state.randint(0, n_samples, n_samples)
9285
sample_counts = bincount(indices, minlength=n_samples)
9386
curr_sample_weight *= sample_counts
@@ -103,8 +96,6 @@ def _parallel_build_trees(n_trees, forest, X, y,
10396
sample_weight=sample_weight,
10497
check_input=False)
10598

106-
trees.append(tree)
107-
10899
return trees
109100

110101

@@ -264,10 +255,13 @@ def fit(self, X, y, sample_weight=None):
264255
" if bootstrap=True")
265256

266257
# Assign chunk of trees to jobs
267-
n_jobs, n_trees, _ = _partition_estimators(self)
258+
n_jobs, n_trees, starts = _partition_estimators(self)
259+
trees = []
268260

269-
# Precalculate the random states
270-
seeds = [random_state.randint(MAX_INT, size=i) for i in n_trees]
261+
for i in range(self.n_estimators):
262+
tree = self._make_estimator(append=False)
263+
tree.set_params(random_state=random_state.randint(MAX_INT))
264+
trees.append(tree)
271265

272266
# Free allocated memory, if any
273267
self.estimators_ = None
@@ -278,12 +272,11 @@ def fit(self, X, y, sample_weight=None):
278272
all_trees = Parallel(n_jobs=n_jobs, verbose=self.verbose,
279273
backend="threading")(
280274
delayed(_parallel_build_trees)(
281-
n_trees[i],
275+
trees[starts[i]:starts[i + 1]],
282276
self,
283277
X,
284278
y,
285279
sample_weight,
286-
seeds[i],
287280
verbose=self.verbose)
288281
for i in range(n_jobs))
289282

sklearn/ensemble/tests/test_forest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,9 @@ def test_distribution():
474474
# Single variable with 4 values
475475
X = rng.randint(0, 4, size=(1000, 1))
476476
y = rng.rand(1000)
477-
n_trees = 200
477+
n_trees = 500
478478

479-
clf = ExtraTreesRegressor(n_estimators=n_trees, random_state=1).fit(X, y)
479+
clf = ExtraTreesRegressor(n_estimators=n_trees, random_state=42).fit(X, y)
480480

481481
uniques = defaultdict(int)
482482
for tree in clf.estimators_:

0 commit comments

Comments
 (0)