@@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq
31
31
lipschitz = np .zeros (n_features , dtype = X .dtype )
32
32
for j in range (n_features ):
33
33
lipschitz [j ] = (X [:, j ] ** 2 ).sum () / len (y )
34
- w = w_init if w_init is not None else np .zeros (n_features )
34
+ w = w_init .copy () if w_init is not None else np .zeros (n_features )
35
+ z = w_init .copy () if w_init is not None else np .zeros (n_features )
36
+ beta_0 = beta_1 = 1
35
37
weights = weights if weights is not None else np .ones (n_features )
36
38
# CD
37
39
for n_iter in range (max_iter ):
38
- cd_epoch (X , G , grads , w , alpha , lipschitz , weights )
40
+ beta_1 = (1 + np .sqrt (1 + 4 * beta_0 ** 2 )) / 2
41
+ cd_epoch (X , G , grads , w , z , alpha , beta_1 , beta_0 , lipschitz , weights )
42
+ beta_0 = beta_1
39
43
if n_iter % check_freq == 0 :
40
44
p_obj = primal (alpha , y , X , w , weights )
41
45
if p_obj_prev - p_obj < tol :
@@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
58
62
for g in range (n_groups ):
59
63
X_g = X [:, grp_indices [grp_ptr [g ]:grp_ptr [g + 1 ]]]
60
64
lipschitz [g ] = norm (X_g , ord = 2 ) ** 2 / len (y )
61
- w = w_init if w_init is not None else np .zeros (n_features )
65
+ w = w_init . copy () if w_init is not None else np .zeros (n_features )
62
66
weights = weights if weights is not None else np .ones (n_groups )
63
67
# BCD
64
68
for n_iter in range (max_iter ):
@@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
74
78
75
79
76
80
@njit
77
- def cd_epoch (X , G , grads , w , alpha , lipschitz , weights ):
81
+ def cd_epoch (X , G , grads , w , z , alpha , beta_1 , beta_0 , lipschitz , weights ):
78
82
n_features = X .shape [1 ]
79
83
for j in range (n_features ):
80
84
if lipschitz [j ] == 0. or weights [j ] == np .inf :
81
85
continue
82
86
old_w_j = w [j ]
83
- w [j ] = ST (w [j ] + grads [j ] / lipschitz [j ], alpha / lipschitz [j ] * weights [j ])
84
- if old_w_j != w [j ]:
85
- grads += G [j , :] * (old_w_j - w [j ]) / len (X )
87
+ old_z_j = z [j ]
88
+ w [j ] = ST (z [j ] + grads [j ] / lipschitz [j ], alpha / lipschitz [j ] * weights [j ])
89
+ z [j ] = w [j ] + ((beta_0 - 1 ) / beta_1 ) * (w [j ] - old_w_j )
90
+ if old_z_j != z [j ]:
91
+ grads += G [j , :] * (old_z_j - z [j ]) / len (X )
86
92
87
93
88
94
@njit
0 commit comments