Skip to content

Commit a7db68b

Browse files
committed
WIP FISTA
1 parent 51b4cfe commit a7db68b

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

skglm/solvers/gram.py

+13 -7
Original file line number | Diff line number | Diff line change
@@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq
3131
lipschitz = np.zeros(n_features, dtype=X.dtype)
3232
for j in range(n_features):
3333
lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
34-
w = w_init if w_init is not None else np.zeros(n_features)
34+
w = w_init.copy() if w_init is not None else np.zeros(n_features)
35+
z = w_init.copy() if w_init is not None else np.zeros(n_features)
36+
beta_0 = beta_1 = 1
3537
weights = weights if weights is not None else np.ones(n_features)
3638
# CD
3739
for n_iter in range(max_iter):
38-
cd_epoch(X, G, grads, w, alpha, lipschitz, weights)
40+
beta_1 = (1 + np.sqrt(1 + 4 * beta_0 ** 2)) / 2
41+
cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights)
42+
beta_0 = beta_1
3943
if n_iter % check_freq == 0:
4044
p_obj = primal(alpha, y, X, w, weights)
4145
if p_obj_prev - p_obj < tol:
@@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
5862
for g in range(n_groups):
5963
X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
6064
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
61-
w = w_init if w_init is not None else np.zeros(n_features)
65+
w = w_init.copy() if w_init is not None else np.zeros(n_features)
6266
weights = weights if weights is not None else np.ones(n_groups)
6367
# BCD
6468
for n_iter in range(max_iter):
@@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
7478

7579

7680
@njit
77-
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
81+
def cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights):
7882
n_features = X.shape[1]
7983
for j in range(n_features):
8084
if lipschitz[j] == 0. or weights[j] == np.inf:
8185
continue
8286
old_w_j = w[j]
83-
w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j])
84-
if old_w_j != w[j]:
85-
grads += G[j, :] * (old_w_j - w[j]) / len(X)
87+
old_z_j = z[j]
88+
w[j] = ST(z[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j])
89+
z[j] = w[j] + ((beta_0 - 1) / beta_1) * (w[j] - old_w_j)
90+
if old_z_j != z[j]:
91+
grads += G[j, :] * (old_z_j - z[j]) / len(X)
8692

8793

8894
@njit

0 commit comments

Comments
 (0)