Skip to content

Commit bb9da84

Browse files
authored
Merge pull request #78 from tinatorabi/main
ASP
2 parents d96fa8d + b706e79 commit bb9da84

File tree

3 files changed

+128
-5
lines changed

3 files changed

+128
-5
lines changed

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["William C Witt <[email protected]>, Christoph Ortner <christophor
44
version = "0.2.1"
55

66
[deps]
7+
ActiveSetPursuit = "d86c1dc8-ba26-4c98-b330-3a8efc174d20"
78
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
89
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -21,25 +22,25 @@ MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
2122
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2223

2324
[extensions]
24-
ACEfit_PythonCall_ext = "PythonCall"
25-
ACEfit_MLJLinearModels_ext = [ "MLJ", "MLJLinearModels" ]
25+
ACEfit_MLJLinearModels_ext = ["MLJ", "MLJLinearModels"]
2626
ACEfit_MLJScikitLearnInterface_ext = ["MLJScikitLearnInterface", "PythonCall", "MLJ"]
27+
ACEfit_PythonCall_ext = "PythonCall"
2728

2829
[compat]
29-
julia = "1.9"
3030
IterativeSolvers = "0.9.2"
31+
LowRankApprox = "0.5.3"
3132
MLJ = "0.19"
3233
MLJLinearModels = "0.9"
3334
MLJScikitLearnInterface = "0.7"
34-
LowRankApprox = "0.5.3"
3535
Optim = "1.7"
3636
ParallelDataTransfer = "0.5.0"
3737
ProgressMeter = "1.7"
3838
PythonCall = "0.9"
3939
StaticArrays = "1.5"
40+
julia = "1.9"
4041

4142
[extras]
4243
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4344

4445
[targets]
45-
test = ["Test", ]
46+
test = ["Test"]

src/solvers.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using LowRankApprox: pqrfact
33
using IterativeSolvers
44
using .BayesianLinear
55
using LinearAlgebra: SVD, svd
6+
using ActiveSetPursuit
67

78
@doc raw"""
89
`struct QR` : linear least squares solver, using standard QR factorisation;
@@ -195,3 +196,95 @@ function solve(solver::TruncatedSVD, A, y)
195196
return Dict{String, Any}("C" => solver.P \ θP)
196197
end
197198

199+
200+
@doc raw"""
201+
`struct ASP` : Active Set Pursuit sparse solver
202+
solves the following optimization problem using the homotopy approach:
203+
204+
```math
205+
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
206+
```
207+
subject to
208+
209+
```math
210+
\|A^T y\|_{\infty} \leq 1.
211+
```
212+
213+
* Input
214+
* `A` : `m`-by-`n` explicit matrix or linear operator.
215+
* `b` : `m`-vector.
216+
217+
* Solver parameters
218+
* `min_lambda` : Minimum value for `λ`. Defaults to zero if not provided.
219+
* `loglevel` : Logging level.
220+
* `itnMax` : Maximum number of iterations.
221+
* `actMax` : Maximum number of active constraints.
222+
223+
Constructor
224+
```julia
225+
ACEfit.ASP(; P = I, select, params)
226+
```
227+
where
228+
- `P` : right-preconditioner / tychonov operator
229+
- `select`: Selection mode for the final solution.
230+
- `(:byerror, th)`: Selects the smallest active set fit within a factor `th` of the smallest fit error.
231+
- `(:final, nothing)`: Returns the final iterate.
232+
- `params`: The solver parameters, passed as named arguments.
233+
"""
234+
struct ASP
235+
P::Any
236+
select::Tuple
237+
params::NamedTuple
238+
end
239+
240+
function ASP(; P = I, select, params...)
241+
params_tuple = NamedTuple(params)
242+
return ASP(P, select, params_tuple)
243+
end
244+
245+
function solve(solver::ASP, A, y)
246+
# Apply preconditioning
247+
AP = A / solver.P
248+
249+
tracer = asp_homotopy(AP, y; solver.params[1]...)
250+
251+
new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, length(tracer))
252+
253+
for i in 1:length(tracer)
254+
new_tracer[i] = (solution = solver.P \ tracer[i][1], λ = tracer[i][2])
255+
end
256+
257+
# Select the final solution based on the criterion
258+
xs, in = select_solution(new_tracer, solver, A, y)
259+
260+
println("done.")
261+
return Dict("C" => xs, "path" => new_tracer, "nnzs" => length((tracer[in][1]).nzind) )
262+
end
263+
264+
function select_solution(tracer, solver, A, y)
265+
criterion, threshold = solver.select
266+
267+
if criterion == :final
268+
return tracer[end][1], length(tracer)
269+
270+
elseif criterion == :byerror
271+
errors = [norm(A * t[1] - y) for t in tracer]
272+
min_error = minimum(errors)
273+
274+
# Find the solution with the smallest error within the threshold
275+
for (i, error) in enumerate(errors)
276+
if error <= threshold * min_error
277+
return tracer[i][1], i
278+
end
279+
end
280+
elseif criterion == :bysize
281+
for i in 1:length(tracer)
282+
if length((tracer[i][1]).nzind) == threshold
283+
return tracer[i][1], i
284+
end
285+
end
286+
else
287+
@error("Unknown selection criterion: $criterion")
288+
end
289+
end
290+

test/test_linearsolvers.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,32 @@ C = results["C"]
111111
@show norm(C)
112112
@show norm(C - c_ref)
113113

114+
@info(" ... ASP_homotopy selected by error")
115+
solver = ACEfit.ASP(P = P, select = (:byerror,1.5), params = (loglevel=0, traceFlag=true))
116+
results = ACEfit.solve(solver, A, y)
117+
C = results["C"]
118+
full_path = results["path"]
119+
@show results["nnzs"]
120+
@show norm(A * C - y)
121+
@show norm(C)
122+
@show norm(C - c_ref)
123+
124+
@info(" ... ASP_homotopy selected by size")
125+
solver = ACEfit.ASP(P = P, select = (:bysize,50), params = (loglevel=0, traceFlag=true))
126+
results = ACEfit.solve(solver, A, y)
127+
C = results["C"]
128+
full_path = results["path"]
129+
@show results["nnzs"]
130+
@show norm(A * C - y)
131+
@show norm(C)
132+
@show norm(C - c_ref)
133+
134+
@info(" ... ASP_homotopy final solution")
135+
solver = ACEfit.ASP(P = P, select = (:final,nothing), params = (loglevel=0, traceFlag=true))
136+
results = ACEfit.solve(solver, A, y)
137+
C = results["C"]
138+
full_path = results["path"]
139+
@show results["nnzs"]
140+
@show norm(A * C - y)
141+
@show norm(C)
142+
@show norm(C - c_ref)

0 commit comments

Comments
 (0)