8
8
9
9
from sklearn .utils .validation import check_is_fitted
10
10
from sklearn .utils import check_array , check_consistent_length
11
- from sklearn .linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
12
11
from sklearn .linear_model ._base import (
13
12
LinearModel , RegressorMixin ,
14
13
LinearClassifierMixin , SparseCoefMixin , BaseEstimator
@@ -1126,8 +1125,7 @@ def fit(self, X, y):
1126
1125
# TODO add predict_proba for LinearSVC
1127
1126
1128
1127
1129
- # TODO we should no longer inherit from sklearn
1130
- class MultiTaskLasso (MultiTaskLasso_sklearn ):
1128
+ class MultiTaskLasso (LinearModel , RegressorMixin ):
1131
1129
r"""MultiTaskLasso estimator.
1132
1130
1133
1131
The optimization objective for MultiTaskLasso is::
@@ -1139,6 +1137,9 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
1139
1137
alpha : float, optional
1140
1138
Regularization strength (constant that multiplies the L21 penalty).
1141
1139
1140
+ copy_X : bool, optional (default=True)
1141
+ If True, X will be copied; else, it may be overwritten.
1142
+
1142
1143
max_iter : int, optional
1143
1144
The maximum number of iterations (subproblem definitions).
1144
1145
@@ -1179,12 +1180,14 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
1179
1180
Number of subproblems solved by Celer to reach the specified tolerance.
1180
1181
"""
1181
1182
1182
- def __init__ (self , alpha = 1. , max_iter = 50 , max_epochs = 50_000 , p0 = 10 ,
1183
+ def __init__ (self , alpha = 1. , copy_X = True , max_iter = 50 , max_epochs = 50_000 , p0 = 10 ,
1183
1184
verbose = 0 , tol = 1e-4 , fit_intercept = True , warm_start = False ,
1184
1185
ws_strategy = "subdiff" ):
1185
- super ().__init__ (
1186
- alpha = alpha , tol = tol ,
1187
- fit_intercept = fit_intercept , warm_start = warm_start )
1186
+ self .tol = tol
1187
+ self .alpha = alpha
1188
+ self .copy_X = copy_X
1189
+ self .warm_start = warm_start
1190
+ self .fit_intercept = fit_intercept
1188
1191
self .max_iter = max_iter
1189
1192
self .p0 = p0
1190
1193
self .ws_strategy = ws_strategy
0 commit comments