Skip to content

Commit 19a5a6f

Browse files
mathurinmPABannierBadr-MOUFAD
authored
API transform solver functions into classes (#63)
Co-authored-by: Pierre-Antoine Bannier <[email protected]> Co-authored-by: Badr MOUFAD <[email protected]>
1 parent b21482c commit 19a5a6f

19 files changed

+1381
-1641
lines changed

examples/plot_logreg_various_penalties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838
clf_enet = GeneralizedLinearEstimator(
3939
Logistic(),
4040
L1_plus_L2(alpha, l1_ratio),
41-
)
41+
)
4242
y_pred_enet = clf_enet.fit(X_train, y_train).predict(X_test)
4343
f1_score_enet = f1_score(y_test, y_pred_enet)
4444

4545
clf_mcp = GeneralizedLinearEstimator(
4646
Logistic(),
4747
MCPenalty(alpha, gamma),
48-
)
48+
)
4949
y_pred_mcp = clf_mcp.fit(X_train, y_train).predict(X_test)
5050
f1_score_mcp = f1_score(y_test, y_pred_mcp)
5151

examples/plot_sparse_recovery.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.metrics import f1_score, mean_squared_error
1717

1818
from skglm.utils import make_correlated_data
19-
from skglm.solvers import cd_solver_path
19+
from skglm.solvers import AndersonCD
2020
from skglm.datafits import Quadratic
2121
from skglm.utils import compiled_clone
2222
from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD
@@ -69,11 +69,13 @@
6969
l0 = {}
7070
mse_ref = mean_squared_error(np.zeros_like(y_test), y_test)
7171

72+
solver = AndersonCD(ws_strategy="fixpoint", fit_intercept=False)
73+
7274
for idx, estimator in enumerate(penalties.keys()):
7375
print(f'Running {estimator}...')
74-
estimator_path = cd_solver_path(
76+
estimator_path = solver.path(
7577
X, y, compiled_clone(datafit), compiled_clone(penalties[estimator]),
76-
alphas=alphas, ws_strategy="fixpoint")
78+
alphas=alphas)
7779

7880
f1_temp = np.zeros(n_alphas)
7981
prediction_error_temp = np.zeros(n_alphas)

0 commit comments

Comments
 (0)