
Commit 3b80cf6

mathurinm, Klopfe, QB3, Badr-MOUFAD, and PABannier authored

ENH fit intercept inside cd_solver (#55)

Co-authored-by: Klopfe <[email protected]>
Co-authored-by: QB3 <[email protected]>
Co-authored-by: Badr MOUFAD <[email protected]>
Co-authored-by: Pierre-Antoine Bannier <[email protected]>
1 parent 92e1266 commit 3b80cf6

13 files changed: +362 −162 lines changed
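The gist of the change: rather than centering the data beforehand, each datafit now exposes an intercept_update_step(y, Xw) method, and cd_solver accepts a fit_intercept flag so the intercept is optimized inside the solver as an extra coefficient w[-1]. Below is a minimal standalone sketch of that idea for the quadratic datafit (illustrative only, not skglm's actual solver code); the step equals the mean residual because the datafit's curvature with respect to the intercept is 1.

# Illustrative sketch, not skglm code: updating an intercept inside a
# coordinate-descent loop via the datafit's intercept_update_step.
import numpy as np


def intercept_update_step(y, Xw):
    # Quadratic datafit (1 / (2 * n)) * ||y - Xw||^2: gradient w.r.t. the
    # intercept; its curvature is 1, so the step equals the gradient.
    return np.mean(Xw - y)


rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = X @ rng.standard_normal(10) + 3.  # data generated with intercept 3

w = np.zeros(X.shape[1])
intercept = 0.
Xw = X @ w + intercept  # Xw tracks predictions, intercept included
for _ in range(10):
    # ... coordinate updates of w would happen here ...
    step = intercept_update_step(y, Xw)
    intercept -= step
    Xw -= step  # keep Xw = X @ w + intercept in sync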

examples/plot_lasso_vs_weighted.py

Lines changed: 2 additions & 1 deletion

@@ -34,7 +34,8 @@
 alpha_max = np.max(np.abs(X.T @ y)) / len(y)
 alpha = alpha_max / 10
 las = Lasso(alpha=alpha, fit_intercept=False).fit(X, y)
-wei = WeightedLasso(alpha=alpha, weights=norm(X, axis=0)).fit(X, y)
+wei = WeightedLasso(
+    alpha=alpha, weights=norm(X, axis=0), fit_intercept=False).fit(X, y)


 fig, axarr = plt.subplots(1, 3, sharey=True, figsize=(10, 2.4))

skglm/datafits/group.py

Lines changed: 3 additions & 0 deletions

@@ -68,3 +68,6 @@ def gradient_g(self, X, y, w, Xw, g):

     def gradient_scalar(self, X, y, w, Xw, j):
         return X[:, j] @ (Xw - y) / len(y)
+
+    def intercept_update_step(self, y, Xw):
+        return np.mean(Xw - y)

skglm/datafits/multi_task.py

Lines changed: 3 additions & 0 deletions

@@ -91,3 +91,6 @@ def full_grad_sparse(self, X_data, X_indptr, X_indices, Y, XW):
                 XjTXW[t] += X_data[i] * XW[X_indices[i], t]
             grad[j, :] = (XjTXW - self.XtY[j, :]) / n_samples
         return grad
+
+    def intercept_update_step(self, Y, XW):
+        return np.sum(XW - Y, axis=0) / len(Y)
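For the multitask quadratic datafit, the step generalizes the single-task mean residual column-wise: one update per task. A quick standalone check (illustrative, not skglm code):

# Standalone check: the multitask step is the column-wise mean residual,
# i.e. one intercept update per task.
import numpy as np

rng = np.random.default_rng(0)
Y = rng.standard_normal((40, 3))   # (n_samples, n_tasks)
XW = rng.standard_normal((40, 3))

step = np.sum(XW - Y, axis=0) / len(Y)  # shape (n_tasks,)
assert np.allclose(step, np.mean(XW - Y, axis=0))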

skglm/datafits/single_task.py

Lines changed: 33 additions & 16 deletions

@@ -87,6 +87,9 @@ def full_grad_sparse(
             grad[j] = (XjTXw - self.Xty[j]) / n_samples
         return grad

+    def intercept_update_step(self, y, Xw):
+        return np.mean(Xw - y)
+

 @njit
 def sigmoid(x):
@@ -169,6 +172,9 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
             grad -= X_data[i] * y[idx_i] * sigmoid(- y[idx_i] * Xw[idx_i])
         return grad / len(Xw)

+    def intercept_update_step(self, y, Xw):
+        return np.mean(- y * sigmoid(- y * Xw)) / 4
+

 class QuadraticSVC(BaseDatafit):
     """A Quadratic SVC datafit used for classification tasks.
@@ -300,32 +306,32 @@ def value(self, y, w, Xw):
         n_samples = len(y)
         res = 0.
         for i in range(n_samples):
-            tmp = abs(y[i] - Xw[i])
-            if tmp < self.delta:
-                res += 0.5 * tmp ** 2
+            residual = abs(y[i] - Xw[i])
+            if residual < self.delta:
+                res += 0.5 * residual ** 2
             else:
-                res += self.delta * tmp - 0.5 * self.delta ** 2
+                res += self.delta * residual - 0.5 * self.delta ** 2
         return res / n_samples

     def gradient_scalar(self, X, y, w, Xw, j):
         n_samples = len(y)
         grad_j = 0.
         for i in range(n_samples):
-            tmp = y[i] - Xw[i]
-            if abs(tmp) < self.delta:
-                grad_j += - X[i, j] * tmp
+            residual = y[i] - Xw[i]
+            if abs(residual) < self.delta:
+                grad_j += - X[i, j] * residual
             else:
-                grad_j += - X[i, j] * np.sign(tmp) * self.delta
+                grad_j += - X[i, j] * np.sign(residual) * self.delta
         return grad_j / n_samples

     def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
         grad_j = 0.
         for i in range(X_indptr[j], X_indptr[j + 1]):
-            tmp = y[X_indices[i]] - Xw[X_indices[i]]
-            if np.abs(tmp) < self.delta:
-                grad_j += - X_data[i] * tmp
+            residual = y[X_indices[i]] - Xw[X_indices[i]]
+            if np.abs(residual) < self.delta:
+                grad_j += - X_data[i] * residual
             else:
-                grad_j += - X_data[i] * np.sign(tmp) * self.delta
+                grad_j += - X_data[i] * np.sign(residual) * self.delta
         return grad_j / len(Xw)

     def full_grad_sparse(
@@ -336,10 +342,21 @@ def full_grad_sparse(
         for j in range(n_features):
             grad_j = 0.
             for i in range(X_indptr[j], X_indptr[j + 1]):
-                tmp = y[X_indices[i]] - Xw[X_indices[i]]
-                if np.abs(tmp) < self.delta:
-                    grad_j += - X_data[i] * tmp
+                residual = y[X_indices[i]] - Xw[X_indices[i]]
+                if np.abs(residual) < self.delta:
+                    grad_j += - X_data[i] * residual
                 else:
-                    grad_j += - X_data[i] * np.sign(tmp) * self.delta
+                    grad_j += - X_data[i] * np.sign(residual) * self.delta
             grad[j] = grad_j / n_samples
         return grad
+
+    def intercept_update_step(self, y, Xw):
+        n_samples = len(y)
+        update = 0.
+        for i in range(n_samples):
+            residual = y[i] - Xw[i]
+            if abs(residual) < self.delta:
+                update -= residual
+            else:
+                update -= np.sign(residual) * self.delta
+        return update / n_samples
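The Huber intercept_update_step above is the datafit's gradient with respect to the intercept: residuals count as-is inside the quadratic region and are clipped to ±delta outside it. A standalone finite-difference check of that identity (illustrative, not skglm code):

# Standalone check: the Huber intercept step matches the numerical gradient
# of the Huber datafit with respect to a uniform shift of Xw (the intercept).
import numpy as np


def huber_value(y, Xw, delta):
    r = y - Xw
    return np.mean(np.where(np.abs(r) < delta,
                            0.5 * r ** 2,
                            delta * np.abs(r) - 0.5 * delta ** 2))


def intercept_update_step(y, Xw, delta):
    # vectorized version of the loop in the diff above
    r = y - Xw
    return np.mean(np.where(np.abs(r) < delta, -r, -np.sign(r) * delta))


rng = np.random.default_rng(0)
y, Xw, delta, eps = rng.standard_normal(30), rng.standard_normal(30), 0.5, 1e-6
num_grad = (huber_value(y, Xw + eps, delta)
            - huber_value(y, Xw - eps, delta)) / (2 * eps)
print(np.isclose(intercept_update_step(y, Xw, delta), num_grad))  # True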

skglm/estimators.py

Lines changed: 29 additions & 18 deletions

@@ -9,7 +9,7 @@
 from sklearn.utils import check_array, check_consistent_length
 from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
 from sklearn.linear_model._base import (
-    _preprocess_data, LinearModel, RegressorMixin,
+    LinearModel, RegressorMixin,
     LinearClassifierMixin, SparseCoefMixin, BaseEstimator
 )
 from sklearn.utils.extmath import softmax
@@ -98,6 +98,8 @@ def _glm_fit(X, y, model, datafit, penalty):
     else:
         X_ = X

+    n_samples, n_features = X_.shape
+
     penalty_jit = compiled_clone(penalty)
     datafit_jit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
     if issparse(X):
@@ -112,22 +114,24 @@ def _glm_fit(X, y, model, datafit, penalty):
             w = model.coef_[0, :].copy()
         else:
             w = model.coef_.copy()
-        Xw = X_ @ w
+        if model.fit_intercept:
+            w = np.hstack([w, model.intercept_])
+        Xw = X_ @ w[:w.shape[0] - model.fit_intercept] + model.fit_intercept * w[-1]
     else:
         # TODO this should be solver.get_init() do delegate the work
         if y.ndim == 1:
-            w = np.zeros(X_.shape[1], dtype=X_.dtype)
-            Xw = np.zeros(X_.shape[0], dtype=X_.dtype)
+            w = np.zeros(n_features + model.fit_intercept, dtype=X_.dtype)
+            Xw = np.zeros(n_samples, dtype=X_.dtype)
         else:  # multitask
-            w = np.zeros((X_.shape[1], y.shape[1]), dtype=X_.dtype)
+            w = np.zeros((n_features + model.fit_intercept, y.shape[1]), dtype=X_.dtype)
             Xw = np.zeros(y.shape, dtype=X_.dtype)

     # check consistency of weights for WeightedL1
     if isinstance(penalty, WeightedL1):
-        if len(penalty.weights) != X.shape[1]:
+        if len(penalty.weights) != n_features:
             raise ValueError(
-                "The size of the WeightedL1 penalty weights should be n_features, \
-                expected %i, got %i" % (X_.shape[1], len(penalty.weights)))
+                "The size of the WeightedL1 penalty weights should be n_features, "
+                "expected %i, got %i." % (X_.shape[1], len(penalty.weights)))

     if is_classif:
         solver = cd_solver  # TODO to be be replaced by an instance of BaseSolver
@@ -141,15 +145,19 @@ def _glm_fit(X, y, model, datafit, penalty):
     coefs, p_obj, kkt = solver(
         X_, y, datafit_jit, penalty_jit, w, Xw, max_iter=model.max_iter,
         max_epochs=model.max_epochs, p0=model.p0,
-        tol=model.tol,  # ws_strategy=model.ws_strategy,
+        tol=model.tol, fit_intercept=model.fit_intercept,
         verbose=model.verbose)
+    model.coef_, model.stop_crit_ = coefs[:n_features], kkt
+    if y.ndim == 1:
+        model.intercept_ = coefs[-1] if model.fit_intercept else 0.
+    else:
+        model.intercept_ = coefs[-1, :] if model.fit_intercept else np.zeros(
+            y.shape[1])

-    model.coef_, model.stop_crit_ = coefs, kkt
     model.n_iter_ = len(p_obj)
-    model.intercept_ = 0.

     if is_classif and n_classes_ <= 2:
-        model.coef_ = coefs[np.newaxis, :]
+        model.coef_ = coefs[np.newaxis, :n_features]
     if isinstance(datafit, QuadraticSVC):
         if is_sparse:
             primal_coef = ((yXT).multiply(model.coef_[0, :])).T
@@ -1212,6 +1220,7 @@ def fit(self, X, y):
     # TODO add predict_proba for LinearSVC


+# TODO we should no longer inherit from sklearn
 class MultiTaskLasso(MultiTaskLasso_sklearn):
     r"""MultiTaskLasso estimator.

@@ -1291,7 +1300,6 @@ def fit(self, X, Y):
         self :
             The fitted estimator.
         """
-        # TODO check if we could just patch `bcd_solver_path` as we do in Lasso case.
         # Below is copied from sklearn, with path replaced by our path.
         # Need to validate separately here.
         # We can't pass multi_output=True because that would allow y to be csr.
@@ -1312,9 +1320,10 @@ def fit(self, X, Y):
             raise ValueError("X and Y have inconsistent dimensions (%d != %d)"
                              % (n_samples, Y.shape[0]))

-        X, Y, X_offset, Y_offset, X_scale = _preprocess_data(
-            X, Y, self.fit_intercept, copy=False)
+        # X, Y, X_offset, Y_offset, X_scale = _preprocess_data(
+        #     X, Y, self.fit_intercept, copy=False)

+        # TODO handle and test warm start for MTL
         if not self.warm_start or not hasattr(self, "coef_"):
             self.coef_ = None
@@ -1324,9 +1333,10 @@ def fit(self, X, Y):
             max_epochs=self.max_epochs, p0=self.p0, verbose=self.verbose,
             tol=self.tol)

-        self.coef_, self.dual_gap_ = coefs[..., 0], kkt[-1]
+        self.coef_ = coefs[:, :X.shape[1], 0]
+        self.intercept_ = self.fit_intercept * coefs[:, -1, 0]
+        self.stopping_crit = kkt[-1]
         self.n_iter_ = len(kkt)
-        self._set_intercept(X_offset, Y_offset, X_scale)

         return self

@@ -1368,4 +1378,5 @@ def path(self, X, Y, alphas, coef_init=None, **params):
         penalty = compiled_clone(self.penalty)

         return multitask_bcd_solver_path(X, Y, datafit, penalty, alphas=alphas,
-                                         coef_init=coef_init, **params)
+                                         coef_init=coef_init,
+                                         fit_intercept=self.fit_intercept, tol=self.tol)
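With the intercept fit inside the solver, estimators can be used directly on uncentered data. A minimal usage sketch, assuming the post-commit skglm API (Lasso with fit_intercept=True):

# Minimal usage sketch, assuming skglm's Lasso as modified by this commit.
import numpy as np
from skglm import Lasso

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
y = X[:, 0] + 2. + 0.01 * rng.standard_normal(100)  # true intercept of 2

model = Lasso(alpha=0.01, fit_intercept=True).fit(X, y)
print(model.coef_.shape)  # (20,): the intercept is not part of coef_
print(model.intercept_)   # close to 2.0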
