Skip to content

Commit 8b08c09

Browse files
authored
MNT remove sklearn inheritance from MultiTaskLasso (#80)
1 parent 3d1f524 commit 8b08c09

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

skglm/estimators.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from sklearn.utils.validation import check_is_fitted
1010
from sklearn.utils import check_array, check_consistent_length
11-
from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
1211
from sklearn.linear_model._base import (
1312
LinearModel, RegressorMixin,
1413
LinearClassifierMixin, SparseCoefMixin, BaseEstimator
@@ -1126,8 +1125,7 @@ def fit(self, X, y):
11261125
# TODO add predict_proba for LinearSVC
11271126

11281127

1129-
# TODO we should no longer inherit from sklearn
1130-
class MultiTaskLasso(MultiTaskLasso_sklearn):
1128+
class MultiTaskLasso(LinearModel, RegressorMixin):
11311129
r"""MultiTaskLasso estimator.
11321130
11331131
The optimization objective for MultiTaskLasso is::
@@ -1139,6 +1137,9 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
11391137
alpha : float, optional
11401138
Regularization strength (constant that multiplies the L21 penalty).
11411139
1140+
copy_X : bool, optional (default=True)
1141+
If True, X will be copied; else, it may be overwritten.
1142+
11421143
max_iter : int, optional
11431144
The maximum number of iterations (subproblem definitions).
11441145
@@ -1179,12 +1180,14 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
11791180
Number of subproblems solved by Celer to reach the specified tolerance.
11801181
"""
11811182

1182-
def __init__(self, alpha=1., max_iter=50, max_epochs=50_000, p0=10,
1183+
def __init__(self, alpha=1., copy_X=True, max_iter=50, max_epochs=50_000, p0=10,
11831184
verbose=0, tol=1e-4, fit_intercept=True, warm_start=False,
11841185
ws_strategy="subdiff"):
1185-
super().__init__(
1186-
alpha=alpha, tol=tol,
1187-
fit_intercept=fit_intercept, warm_start=warm_start)
1186+
self.tol = tol
1187+
self.alpha = alpha
1188+
self.copy_X = copy_X
1189+
self.warm_start = warm_start
1190+
self.fit_intercept = fit_intercept
11881191
self.max_iter = max_iter
11891192
self.p0 = p0
11901193
self.ws_strategy = ws_strategy

0 commit comments

Comments
 (0)