Skip to content

Commit 7a595b8

Browse files
committed
COSMIT: simplify parallel code in multiclass
1 parent d0feab0 commit 7a595b8

File tree

1 file changed

+10
-27
lines changed

1 file changed

+10
-27
lines changed

sklearn/multiclass.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,16 @@ def _check_estimator(estimator):
7575
"decision_function or predict_proba!")
7676

7777

78-
def fit_ovr(estimator, X, y, n_jobs):
78+
def fit_ovr(estimator, X, y, n_jobs=1):
7979
"""Fit a one-vs-the-rest strategy."""
8080
_check_estimator(estimator)
8181

8282
lb = LabelBinarizer()
8383
Y = lb.fit_transform(y)
84-
classes = []
85-
for i in range(Y.shape[1]):
86-
classes.append(["not %s" % i, i])
87-
88-
if n_jobs == 1:
89-
estimators = [_fit_binary(estimator, X, Y[:, i],
90-
classes=classes[i])
91-
for i in range(Y.shape[1])]
92-
else:
93-
estimators = Parallel(n_jobs=n_jobs)(
94-
delayed(_fit_binary)(estimator, X, Y[:, i],
95-
classes=classes[i])
96-
for i in range(Y.shape[1]))
84+
85+
estimators = Parallel(n_jobs=n_jobs)(
86+
delayed(_fit_binary)(estimator, X, Y[:, i], classes=["not %s" % i, i])
87+
for i in range(Y.shape[1]))
9788
return estimators, lb
9889

9990

@@ -296,15 +287,11 @@ def _fit_ovo_binary(estimator, X, y, i, j):
296287
return _fit_binary(estimator, X[ind[cond]], y, classes=[i, j])
297288

298289

299-
def fit_ovo(estimator, X, y, n_jobs):
290+
def fit_ovo(estimator, X, y, n_jobs=1):
300291
"""Fit a one-vs-one strategy."""
301292
classes = np.unique(y)
302293
n_classes = classes.shape[0]
303-
if n_jobs == 1:
304-
estimators = [_fit_ovo_binary(estimator, X, y, classes[i], classes[j])
305-
for i in range(n_classes) for j in range(i + 1, n_classes)]
306-
else:
307-
estimators = Parallel(n_jobs=n_jobs)(
294+
estimators = Parallel(n_jobs=n_jobs)(
308295
delayed(_fit_ovo_binary)(estimator, X, y, classes[i], classes[j])
309296
for i in range(n_classes) for j in range(i + 1, n_classes))
310297

@@ -455,13 +442,9 @@ def fit_ecoc(estimator, X, y, code_size=1.5, random_state=None, n_jobs=1):
455442
Y = np.array([code_book[cls_idx[y[i]]] for i in xrange(X.shape[0])],
456443
dtype=np.int)
457444

458-
if n_jobs == 1:
459-
estimators = [_fit_binary(estimator, X, Y[:, i])
460-
for i in range(Y.shape[1])]
461-
else:
462-
estimators = Parallel(n_jobs=n_jobs)(
463-
delayed(_fit_binary)(estimator, X, Y[:, i])
464-
for i in range(Y.shape[1]))
445+
estimators = Parallel(n_jobs=n_jobs)(
446+
delayed(_fit_binary)(estimator, X, Y[:, i])
447+
for i in range(Y.shape[1]))
465448

466449
return estimators, classes, code_book
467450

0 commit comments

Comments
 (0)