Skip to content

Commit 8062ad9

Browse files
authored
Merge pull request #88 from ACEsuit/aspselect
Post fit selection from asp path
2 parents 391ec74 + ce319ec commit 8062ad9

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

src/asp.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,28 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
8383
for p in new_tracer ]
8484
end
8585

86-
xs, in = select_solution(new_post, solver, Aval, yval)
86+
tracer_final = _add_errors(new_post, Aval, yval)
87+
xs, in = asp_select(tracer_final, solver.select)
8788

88-
return Dict( "C" => xs,
89-
"path" => new_post,
90-
"nnzs" => length( (new_post[in][:solution]).nzind) )
89+
return Dict( "C" => xs,
90+
"path" => tracer_final, )
9191
end
9292

9393

94-
function select_solution(tracer, solver, A, y)
95-
if solver.select == :final
94+
function _add_errors(tracer, A, y)
95+
rtN = sqrt(length(y))
96+
return [ ( solution = t.solution, λ = t.λ, σ = t.σ,
97+
rmse = norm(A * t.solution - y) / rtN )
98+
for t in tracer ]
99+
end
100+
101+
asp_select(D::Dict, select) = asp_select(D["path"], select)
102+
103+
function asp_select(tracer, select)
104+
if select == :final
96105
criterion = :final
97106
else
98-
criterion, p = solver.select
107+
criterion, p = select
99108
end
100109

101110
if criterion == :final
@@ -108,12 +117,12 @@ function select_solution(tracer, solver, A, y)
108117
elseif criterion == :bysize
109118
maxind = findfirst(t -> length((t[:solution]).nzind) > p,
110119
tracer) - 1
111-
threshold = 1.0
120+
threshold = 1.0
112121
else
113122
error("Unknown selection criterion: $criterion")
114123
end
115124

116-
errors = [ norm(A * t[:solution] - y) for t in tracer[1:maxind] ]
125+
errors = [ t.rmse for t in tracer[1:maxind] ]
117126
min_error = minimum(errors)
118127
for (i, error) in enumerate(errors)
119128
if error <= threshold * min_error
@@ -140,3 +149,16 @@ function post_asp_tsvd(path, At, yt, Av, yv)
140149

141150
return _post.(path)
142151
end
152+
153+
# TODO: revisit this idea. Maybe we do want to keep this, not as `select`
154+
# but as `solve`. But if we do, then it might be nice to be able to
155+
# extend the path somehow. For now I'm removing it since I don't see
156+
# the immediate need yet. Just calling asp_select is how I would normally
157+
# use this.
158+
#
159+
# function select(tracer, solver, A, y) #can be called by the user to warm-start the selection
160+
# xs, in = select_solution(tracer, solver, A, y)
161+
# return Dict("C" => xs,
162+
# "path" => tracer,
163+
# "nnzs" => length( (tracer[in][:solution]).nzind) )
164+
# end

test/test_asp.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ACEfit
2-
using LinearAlgebra, Random, Test
2+
using LinearAlgebra, Random, Test
33

44
##
55

@@ -47,7 +47,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
4747
results = ACEfit.solve(solver, A, y)
4848
C = results["C"]
4949
full_path = results["path"]
50-
@show results["nnzs"]
50+
# @show results["nnzs"]
5151
@show norm(A * C - y)
5252
@show norm(C)
5353
@show norm(C - c_ref)
@@ -60,7 +60,7 @@ for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
6060
results = ACEfit.solve(solver, At, yt, Av, yv)
6161
C = results["C"]
6262
full_path = results["path"]
63-
@show results["nnzs"]
63+
# @show results["nnzs"]
6464
@show norm(Av * C - yv)
6565
@show norm(C)
6666
@show norm(C - c_ref)
@@ -91,7 +91,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
9191
C_tsvd = results_tsvd["C"]
9292
C = results["C"]
9393

94-
@show results["nnzs"]
94+
# @show results["nnzs"]
9595
@show norm(A * C - y)
9696
@show norm(A * C_tsvd - y)
9797
if norm(A * C_tsvd - y)< norm(A * C - y)
@@ -106,7 +106,7 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
106106
results = ACEfit.solve(solver, At, yt, Av, yv)
107107
C_tsvd = results_tsvd["C"]
108108
C = results["C"]
109-
@show results["nnzs"]
109+
# @show results["nnzs"]
110110
@show norm(A * C - y)
111111
@show norm(A * C_tsvd - y)
112112

@@ -117,3 +117,36 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
117117
end
118118
end
119119

120+
##
121+
122+
# Testing the "select" function
123+
solver_final = ACEfit.ASP(
124+
P = I,
125+
select = :final,
126+
tsvd = false,
127+
nstore = 100,
128+
loglevel = 0
129+
)
130+
131+
results_final = ACEfit.solve(solver_final, At, yt, Av, yv)
132+
tracer_final = results_final["path"]
133+
134+
# Warm-start the solver using the tracer from the final iteration
135+
# select best solution with <= 73 non-zero entries
136+
select = (:bysize, 73)
137+
C_select, _ = ACEfit.asp_select(tracer_final, select)
138+
@test( length(C_select.nzind) <= 73 )
139+
140+
# Check if starting the solver initially with (:bysize, 73) gives the same result
141+
solver_bysize = ACEfit.ASP(
142+
P = I,
143+
select = (:bysize, 73),
144+
tsvd = false,
145+
nstore = 100,
146+
loglevel = 0
147+
)
148+
149+
results_bysize = ACEfit.solve(solver_bysize, At, yt, Av, yv)
150+
@test results_bysize["C"] == C_select # works
151+
152+

0 commit comments

Comments
 (0)