Skip to content

Commit 054a2ab

Browse files
authored
Merge pull request #89 from ACEsuit/asp
ASP and OMP with D&C QR
2 parents daa09c6 + a2de647 commit 054a2ab

File tree

3 files changed

+230
-131
lines changed

3 files changed

+230
-131
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ACEfit_MLJScikitLearnInterface_ext = ["MLJScikitLearnInterface", "PythonCall", "
2828
ACEfit_PythonCall_ext = "PythonCall"
2929

3030
[compat]
31-
ActiveSetPursuit = "0.0.2"
31+
ActiveSetPursuit = "0.0.4"
3232
IterativeSolvers = "0.9.2"
3333
LowRankApprox = "0.5.3"
3434
MLJ = "0.19"

src/asp.jl

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Distributed
2+
using LinearAlgebra
13

24
@doc raw"""
35
@@ -66,7 +68,18 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
6668
# Apply preconditioning
6769
AP = A / solver.P
6870
AvalP = Aval / solver.P
69-
tracer = asp_homotopy(AP, y; solver.params..., traceFlag = true)
71+
72+
F = qr!(AP)
73+
m, n = size(AP)
74+
if m < n
75+
error("ASP requires m >= n, but got m = $m, n = $n")
76+
end
77+
Qtb = F.Q' * y
78+
Qtb1 = Qtb[1:n]
79+
80+
tracer = asp_homotopy(F.R, Qtb1; solver.params..., traceFlag = true)
81+
82+
# tracer = asp_homotopy(AP, y; solver.params..., traceFlag = true)
7083

7184
q = length(tracer)
7285
every = max(1, q / solver.nstore)
@@ -75,7 +88,7 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
7588
for i in istore ]
7689

7790
if solver.tsvd # Post-processing if tsvd is true
78-
post = post_asp_tsvd(new_tracer, AP, y, AvalP, yval)
91+
post = post_asp_tsvd(new_tracer, F.R, Qtb1, AvalP, yval)
7992
new_post = [ (solution = solver.P \ p.θ, λ = p.λ, σ = p.σ)
8093
for p in post ]
8194
else
@@ -162,3 +175,105 @@ end
162175
# "path" => tracer,
163176
# "nnzs" => length( (tracer[in][:solution]).nzind) )
164177
# end
178+
179+
180+
181+
@doc raw"""
182+
183+
`OMP` : Orthogonal Matching Pursuit solver
184+
185+
Solves the lasso optimization problem.
186+
```math
187+
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
188+
```
189+
subject to
190+
```math
191+
\|A^T y\|_{\infty} \leq 1.
192+
```
193+
194+
### Constructor Keyword arguments
195+
196+
```julia
197+
ACEfit.ASP(; P = I, select = (:byerror, 1.0), tsvd = false, nstore=100,
198+
params...)
199+
```
200+
201+
* `select` : Selection criterion for the final solution (required)
202+
* `:final` : final solution (largest computed basis)
203+
* `(:byerror, q)` : solution with error within `q` times the minimum error
204+
along the path; if training error is used and `q == 1.0`, then this is
205+
equivalent to to `:final`.
206+
* `(:bysize, n)` : best solution with at most `n` non-zero features; if
207+
training error is used, then it will be the solution with exactly `n`
208+
non-zero features.
209+
* `P = I` : prior / regularizer (optional)
210+
211+
The remaining kwarguments to `ASP` are parameters for the ASP homotopy solver.
212+
213+
* `actMax` : Maximum number of active constraints.
214+
* `min_lambda` : Minimum value for `λ`. (defaults to 0)
215+
* `loglevel` : Logging level.
216+
* `itnMax` : Maximum number of iterations.
217+
218+
### Extended syntax for `solve`
219+
220+
```julia
221+
solve(solver::ASP, A, y, Aval=A, yval=y)
222+
```
223+
* `A` : `m`-by-`n` design matrix. (required)
224+
* `b` : `m`-vector. (required)
225+
* `Aval = nothing` : `p`-by-`n` validation matrix
226+
* `bval = nothing` : `p`- validation vector
227+
228+
If independent `Aval` and `yval` are provided (instead of detaults `A, y`),
229+
then the solver will use this separate validation set instead of the training
230+
set to select the best solution along the model path.
231+
"""
232+
struct OMP
233+
P
234+
select
235+
tsvd::Bool
236+
nstore::Integer
237+
params
238+
end
239+
240+
function OMP(; P = I, select, tsvd=false, nstore=100, params...)
241+
return OMP(P, select, tsvd, nstore, params)
242+
end
243+
244+
function solve(solver::OMP, A, y, Aval=A, yval=y)
245+
# Apply preconditioning
246+
AP = A / solver.P
247+
AvalP = Aval / solver.P
248+
249+
F = qr!(AP)
250+
m, n = size(AP)
251+
if m < n
252+
error("OMP requires m >= n, but got m = $m, n = $n")
253+
end
254+
Qtb = F.Q' * y
255+
Qtb1 = Qtb[1:n]
256+
257+
tracer = asp_omp(F.R, Qtb1, 0.0; traceFlag=true, loglevel=0, solver.params...)
258+
259+
q = length(tracer)
260+
every = max(1, q / solver.nstore)
261+
istore = unique(round.(Int, [1:every:q; q]))
262+
new_tracer = [ (solution = tracer[i][1], λ = tracer[i][2], σ = 0.0 )
263+
for i in istore ]
264+
265+
if solver.tsvd # Post-processing if tsvd is true
266+
post = post_asp_tsvd(new_tracer, F.R, Qtb1, AvalP, yval)
267+
new_post = [ (solution = solver.P \ p.θ, λ = p.λ, σ = p.σ)
268+
for p in post ]
269+
else
270+
new_post = [ (solution = solver.P \ p.solution, λ = p.λ, σ = 0.0)
271+
for p in new_tracer ]
272+
end
273+
274+
tracer_final = _add_errors(new_post, Aval, yval)
275+
xs, in = asp_select(tracer_final, solver.select)
276+
277+
return Dict( "C" => xs,
278+
"path" => tracer_final, )
279+
end

0 commit comments

Comments
 (0)