Skip to content

ENH Add FISTA solver #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Oct 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0868b0f
POC FISTA
PABannier Oct 12, 2022
8584299
CLN
PABannier Oct 14, 2022
c82e32e
changed obj_freq from 100 to 10
PABannier Oct 14, 2022
4940a0d
WIP Lipschitz
PABannier Oct 14, 2022
e47c68a
ADD global lipschitz constants
PABannier Oct 14, 2022
3635f24
FISTA with global lipschitz
PABannier Oct 14, 2022
4880112
writing tests
PABannier Oct 14, 2022
46a9a76
better tests
PABannier Oct 14, 2022
9f0653a
support sparse matrices
PABannier Oct 14, 2022
fe159be
fix mistake
PABannier Oct 14, 2022
8e74e8a
RM toy_fista
PABannier Oct 14, 2022
a24ed9c
green
PABannier Oct 14, 2022
4362c2c
mv `_prox_vec` to utils
PABannier Oct 16, 2022
2665d5d
rm `opt_freq`
PABannier Oct 16, 2022
2e408bc
fix tests
PABannier Oct 16, 2022
8524cf7
Update skglm/solvers/fista.py
PABannier Oct 16, 2022
dd658f8
huber comment
PABannier Oct 16, 2022
7c9fbe1
Merge branch 'fista' of https://github.com/PABannier/skglm into fista
PABannier Oct 16, 2022
cbc5418
WIP
PABannier Oct 16, 2022
b6c664c
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Oct 20, 2022
e76dfb1
implement power method
Badr-MOUFAD Oct 20, 2022
2a4bce3
private ``prox_vec``
Badr-MOUFAD Oct 20, 2022
cd39a62
random init in power method && default args
Badr-MOUFAD Oct 21, 2022
0e4d42a
use power method for ``global_lipschitz``
Badr-MOUFAD Oct 21, 2022
2bbc8f5
fix && refactor unittest
Badr-MOUFAD Oct 21, 2022
ed3686a
add docs for tol and max_iter && clean ups
Badr-MOUFAD Oct 21, 2022
aa15c46
remove square form spectral norm
Badr-MOUFAD Oct 21, 2022
27b918d
refactor ``_prox_vec`` function
Badr-MOUFAD Oct 21, 2022
9d8e3c0
fix bug segmentation fault
Badr-MOUFAD Oct 21, 2022
e5ce21b
add Fista to docs && fix unittest
Badr-MOUFAD Oct 21, 2022
5d2dbaf
cosmetic changes
mathurinm Oct 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Solvers
:toctree: generated/

AndersonCD
FISTA
GramCD
GroupBCD
MultiTaskBCD
Expand Down
60 changes: 52 additions & 8 deletions skglm/datafits/single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numba import float64

from skglm.datafits.base import BaseDatafit
from skglm.utils import spectral_norm


class Quadratic(BaseDatafit):
Expand All @@ -22,6 +23,10 @@ class Quadratic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.

Note
----
The class is jit compiled at fit time using Numba compiler.
Expand All @@ -35,6 +40,7 @@ def get_spec(self):
spec = (
('Xty', float64[:]),
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

Expand All @@ -44,14 +50,18 @@ def params_to_dict(self):
def initialize(self, X, y):
    """Precompute ``Xty`` and the Lipschitz constants for a dense ``X``.

    Sets ``self.Xty`` to ``X.T @ y``, ``self.global_lipschitz`` to
    ``norm(X, ord=2) ** 2 / len(y)`` and ``self.lipschitz[j]`` to the
    squared norm of column ``j`` divided by ``len(y)``.
    """
    n_obs = len(y)
    n_cols = X.shape[1]

    self.Xty = X.T @ y
    self.global_lipschitz = norm(X, ord=2) ** 2 / n_obs

    # Coordinatewise constants, filled one column at a time.
    col_constants = np.zeros(n_cols, dtype=X.dtype)
    for col_idx in range(n_cols):
        column = X[:, col_idx]
        col_constants[col_idx] = (column ** 2).sum() / n_obs
    self.lipschitz = col_constants

def initialize_sparse(
self, X_data, X_indptr, X_indices, y):
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
self.Xty = np.zeros(n_features, dtype=X_data.dtype)

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down Expand Up @@ -111,6 +121,10 @@ class Logistic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / (4 * n_samples).

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / (4 * n_samples).

Note
----
The class is jit compiled at fit time using Numba compiler.
Expand All @@ -123,6 +137,7 @@ def __init__(self):
def get_spec(self):
    """Return the (attribute name, Numba type) pairs of this datafit."""
    return (
        ('lipschitz', float64[:]),
        ('global_lipschitz', float64),
    )

Expand All @@ -140,9 +155,14 @@ def raw_hessian(self, y, Xw):

def initialize(self, X, y):
    """Precompute the Lipschitz constants for a dense ``X``.

    Both constants share the ``4 * len(y)`` denominator:
    ``self.lipschitz`` holds the squared column norms of ``X`` and
    ``self.global_lipschitz`` the squared spectral norm of ``X``.
    """
    denom = len(y) * 4
    self.global_lipschitz = norm(X, ord=2) ** 2 / denom
    self.lipschitz = (X ** 2).sum(axis=0) / denom

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= 4 * len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
Expand Down Expand Up @@ -187,6 +207,11 @@ class QuadraticSVC(BaseDatafit):
----------
lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
Equal to norm(yXT, axis=0) ** 2.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(yXT, ord=2) ** 2.

Note
----
Expand All @@ -200,6 +225,7 @@ def __init__(self):
def get_spec(self):
    """Return the (attribute name, Numba type) pairs of this datafit."""
    return (
        ('lipschitz', float64[:]),
        ('global_lipschitz', float64),
    )

Expand All @@ -209,12 +235,16 @@ def params_to_dict(self):
def initialize(self, yXT, y):
    """Precompute the Lipschitz constants for a dense ``yXT``.

    ``self.global_lipschitz`` is the squared spectral norm of ``yXT`` and
    ``self.lipschitz[j]`` the squared Euclidean norm of its column ``j``.
    ``y`` is not read here.
    """
    n_cols = yXT.shape[1]
    self.global_lipschitz = norm(yXT, ord=2) ** 2

    col_constants = np.zeros(n_cols, dtype=yXT.dtype)
    for col_idx in range(n_cols):
        col_constants[col_idx] = norm(yXT[:, col_idx]) ** 2
    self.lipschitz = col_constants

def initialize_sparse(
self, yXT_data, yXT_indptr, yXT_indices, y):
def initialize_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
n_features = len(yXT_indptr) - 1

self.global_lipschitz = spectral_norm(
yXT_data, yXT_indptr, yXT_indices, max(yXT_indices)+1) ** 2

self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down Expand Up @@ -264,8 +294,16 @@ class Huber(BaseDatafit):

Attributes
----------
delta : float
Threshold hyperparameter.

lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.

global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.

Note
----
Expand All @@ -279,7 +317,8 @@ def __init__(self, delta):
def get_spec(self):
spec = (
('delta', float64),
('lipschitz', float64[:])
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

Expand All @@ -289,12 +328,17 @@ def params_to_dict(self):
def initialize(self, X, y):
    """Precompute the Lipschitz constants for a dense ``X``.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Dense design matrix.

    y : array, shape (n_samples,)
        Target vector; only its length is used here.
    """
    n_samples = len(y)
    n_features = X.shape[1]

    # Use the squared spectral norm, as stated in the class docstring
    # (norm(X, ord=2) ** 2 / n_samples). Summing the coordinatewise
    # constants instead yields the squared Frobenius norm, which is a valid
    # but much looser bound and forces needlessly small gradient steps.
    self.global_lipschitz = norm(X, ord=2) ** 2 / n_samples

    self.lipschitz = np.zeros(n_features, dtype=X.dtype)
    for j in range(n_features):
        self.lipschitz[j] = (X[:, j] ** 2).sum() / n_samples

def initialize_sparse(
self, X_data, X_indptr, X_indices, y):
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down
3 changes: 2 additions & 1 deletion skglm/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .anderson_cd import AndersonCD
from .base import BaseSolver
from .fista import FISTA
from .gram_cd import GramCD
from .group_bcd import GroupBCD
from .multitask_bcd import MultiTaskBCD
from .prox_newton import ProxNewton


__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
# ``__all__`` must contain names as strings: ``from skglm.solvers import *``
# raises ``TypeError`` when it holds the class objects themselves.
__all__ = [
    "AndersonCD",
    "BaseSolver",
    "FISTA",
    "GramCD",
    "GroupBCD",
    "MultiTaskBCD",
    "ProxNewton",
]
82 changes: 82 additions & 0 deletions skglm/solvers/fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
from scipy.sparse import issparse
from skglm.solvers.base import BaseSolver
from skglm.solvers.common import construct_grad, construct_grad_sparse
from skglm.utils import _prox_vec


class FISTA(BaseSolver):
    r"""ISTA solver with Nesterov acceleration (FISTA).

    Attributes
    ----------
    max_iter : int, default 100
        Maximum number of iterations.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool or int, default 0
        Amount of verbosity. 0/False is silent.

    References
    ----------
    .. [1] Beck, A. and Teboulle M.
           "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
           problems", 2009, SIAM J. Imaging Sci.
           https://epubs.siam.org/doi/10.1137/080716542
    """

    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        """Run FISTA with a constant step size ``1 / datafit.global_lipschitz``.

        Parameters
        ----------
        X : array or sparse matrix, shape (n_samples, n_features)
            Design matrix. For sparse input the ``data``, ``indptr`` and
            ``indices`` attributes are read.

        y : array
            Target passed to the datafit.

        datafit : object
            Initialized datafit; it must expose ``global_lipschitz``.

        penalty : object
            Penalty providing ``value`` and ``subdiff_distance``.

        w_init : array, shape (n_features,), optional
            Starting point. It is copied, never modified. Defaults to zeros.

        Xw_init : array, optional
            Accepted for interface compatibility; it is not used because
            ``X @ w`` is recomputed at every iteration.

        Returns
        -------
        w : array, shape (n_features,)
            Last iterate.

        p_objs_out : array
            Objective value after each performed iteration.

        stop_crit : float
            Last value of the stopping criterion; ``np.inf`` if no iteration
            was performed (``max_iter == 0``).

        Raises
        ------
        NotImplementedError
            If ``datafit`` has no ``global_lipschitz`` attribute.
        """
        n_features = X.shape[1]
        all_features = np.arange(n_features)
        p_objs_out = []
        t_new = 1.
        # Defined up front so the return statement is valid when the loop
        # body never runs.
        stop_crit = np.inf

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        z = w_init.copy() if w_init is not None else np.zeros(n_features)

        if hasattr(datafit, "global_lipschitz"):
            lipschitz = datafit.global_lipschitz
        else:
            # TODO: OR line search
            raise NotImplementedError(
                "Line search is not yet implemented for FISTA solver.")

        # Loop invariants.
        step = 1 / lipschitz
        X_is_sparse = issparse(X)

        for n_iter in range(self.max_iter):
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            # Keep a copy of the previous iterate for the momentum term;
            # presumably ``_prox_vec`` may write into ``w`` -- not visible here.
            w_old = w.copy()

            # Gradient of the datafit at the extrapolated point ``z``.
            if X_is_sparse:
                grad = construct_grad_sparse(
                    X.data, X.indptr, X.indices, y, z, X @ z, datafit,
                    all_features)
            else:
                grad = construct_grad(X, y, z, X @ z, datafit, all_features)

            # Gradient step on z, proximal step to get w, then extrapolation.
            z -= step * grad
            w = _prox_vec(w, z, penalty, step)
            Xw = X @ w
            z = w + (t_old - 1.) / t_new * (w - w_old)

            # NOTE(review): ``grad`` was evaluated at the previous ``z``, not
            # at the new ``w``, so this criterion mixes the two points.
            opt = penalty.subdiff_distance(w, grad, all_features)
            stop_crit = np.max(opt)

            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
            p_objs_out.append(p_obj)
            if self.verbose:
                print(
                    f"Iteration {n_iter+1}: {p_obj:.10f}, "
                    f"stopping crit: {stop_crit:.2e}"
                )

            if stop_crit < self.tol:
                if self.verbose:
                    print(f"Stopping criterion max violation: {stop_crit:.2e}")
                break
        return w, np.array(p_objs_out), stop_crit
69 changes: 69 additions & 0 deletions skglm/tests/test_fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

import numpy as np
from numpy.linalg import norm

import scipy.sparse
import scipy.sparse.linalg
from scipy.sparse import csc_matrix, issparse

from skglm.penalties import L1, IndicatorBox
from skglm.solvers import FISTA, AndersonCD
from skglm.datafits import Quadratic, Logistic, QuadraticSVC
from skglm.utils import make_correlated_data, compiled_clone, spectral_norm


# Shared fixtures for the tests below: a dense design, a sparsified copy of it,
# regression and classification targets, and the regularization strength.
random_state = 113
n_samples, n_features = 50, 60

rng = np.random.RandomState(random_state)
X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
rng.seed(random_state)
# Draw the sparsity mask from the seeded generator: the global ``np.random``
# state is not controlled here, which made ``X_sparse`` differ between runs.
X_sparse = csc_matrix(X * rng.binomial(1, 0.5, X.shape))
y_classif = np.sign(y)

alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = alpha_max / 10

tol = 1e-10


@pytest.mark.parametrize("X", [X, X_sparse])
@pytest.mark.parametrize("Datafit, Penalty", [
    (Quadratic, L1),
    (Logistic, L1),
    (QuadraticSVC, IndicatorBox),
])
def test_fista_solver(X, Datafit, Penalty):
    # Check that FISTA and AndersonCD reach the same solution for each
    # (datafit, penalty) pair, on both the dense and the sparse design.

    # NOTE(review): ``Datafit`` is parametrized with classes, not instances,
    # so ``isinstance(Datafit, Quadratic)`` looks always False and ``_y`` would
    # always be ``y_classif`` -- confirm whether ``Datafit is Quadratic`` was
    # intended.
    _y = y if isinstance(Datafit, Quadratic) else y_classif
    datafit = compiled_clone(Datafit())
    # NOTE(review): same pattern here, so ``_init`` appears to always be ``X``
    # and the ``y @ X.T`` branch is presumably never taken -- confirm.
    _init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
    if issparse(X):
        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
    else:
        datafit.initialize(_init, _y)
    penalty = compiled_clone(Penalty(alpha))

    solver = FISTA(max_iter=1000, tol=tol)
    w_fista = solver.solve(X, _y, datafit, penalty)[0]

    # Reference solution computed with coordinate descent at the same tolerance.
    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]

    np.testing.assert_allclose(w_fista, w_cd, atol=1e-7)


def test_spectral_norm():
    """Compare ``spectral_norm`` with SciPy's largest singular value."""
    n_samples, n_features = 50, 60
    A_sparse = scipy.sparse.random(n_samples, n_features, density=0.7, format='csc',
                                   random_state=random_state)

    A_bundles = (A_sparse.data, A_sparse.indptr, A_sparse.indices)
    # Pass the row count of the matrix under test rather than the length of the
    # unrelated module-level ``y``, which only matched by coincidence.
    spectral_norm_our = spectral_norm(*A_bundles, n_samples=n_samples)
    spectral_norm_sp = scipy.sparse.linalg.svds(A_sparse, k=1)[1]

    np.testing.assert_allclose(spectral_norm_our, spectral_norm_sp)


# Script entry point: intentionally does nothing when the module is run directly.
if __name__ == '__main__':
    pass
Loading