Skip to content

Commit b021f6a

Browse files
authored
Merge pull request #83 from ACEsuit/asp_svd
TSVD Postprocessing of ASP
2 parents e0fea2e + 83dbee5 commit b021f6a

File tree

5 files changed

+183
-89
lines changed

5 files changed

+183
-89
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1313
ParallelDataTransfer = "2dcacdae-9679-587a-88bb-8b444fb7085b"
1414
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1515
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
16+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1617
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718

1819
[weakdeps]

src/asp.jl

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,45 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
4848
If independent `Aval` and `yval` are provided (instead of defaults `A, y`),
4949
then the solver will use this separate validation set instead of the training
5050
set to select the best solution along the model path.
51-
"""
51+
# """
52+
5253
struct ASP
5354
P
5455
select
56+
mode::Symbol
57+
tsvd::Bool
58+
nstore::Integer
5559
params
5660
end
5761

58-
function ASP(; P = I, select, mode=:train, params...)
59-
return ASP(P, select, params)
62+
function ASP(; P = I, select, mode=:train, tsvd=false, nstore=100, params...)
63+
return ASP(P, select, mode, tsvd, nstore, params)
6064
end
6165

6266
function solve(solver::ASP, A, y, Aval=A, yval=y)
6367
# Apply preconditioning
6468
AP = A / solver.P
6569

6670
tracer = asp_homotopy(AP, y; solver.params...)
67-
q = length(tracer)
68-
new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, q)
6971

70-
for i in 1:q
71-
new_tracer[i] = (solution = solver.P \ tracer[i][1], λ = tracer[i][2])
72+
q = length(tracer)
73+
every = max(1, q ÷ solver.nstore)
74+
istore = unique([1:every:q; q])
75+
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
76+
for i in istore ]
77+
78+
if solver.tsvd # Post-processing if tsvd is true
79+
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
80+
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
81+
else
82+
new_post = new_tracer
7283
end
7384

74-
xs, in = select_solution(new_tracer, solver, Aval, yval)
85+
xs, in = select_solution(new_post, solver, Aval, yval)
7586

76-
# println("done.")
77-
return Dict( "C" => xs,
78-
"path" => new_tracer,
79-
"nnzs" => length((new_tracer[in][:solution]).nzind) )
87+
return Dict( "C" => xs,
88+
"path" => new_post,
89+
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
8090
end
8191

8292

@@ -114,44 +124,56 @@ function select_solution(tracer, solver, A, y)
114124
end
115125

116126

117-
#=
118-
function select_smart(tracer, solver, Aval, yval)
119127

120-
best_metric = Inf
121-
best_iteration = 0
122-
validation_metric = 0
123-
q = length(tracer)
124-
errors = [norm(Aval * t[:solution] - yval) for t in tracer]
125-
nnzss = [(t[:solution]).nzind for t in tracer]
126-
best_iteration = argmin(errors)
127-
validation_metric = errors[best_iteration]
128-
validation_end = norm(Aval * tracer[end][:solution] - yval)
128+
using SparseArrays
129129

130-
if validation_end < validation_metric #make sure to check the last one too in case q<<100
131-
best_iteration = q
132-
end
130+
function solve_tsvd(At, yt, Av, yv)
131+
Ut, Σt, Vt = svd(At); zt = Ut' * yt
132+
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
133+
@assert issorted(Σt, rev=true)
133134

134-
criterion, threshold = solver.select
135-
136-
if criterion == :val
137-
return tracer[best_iteration][:solution], best_iteration
138-
139-
elseif criterion == :byerror
140-
for (i, error) in enumerate(errors)
141-
if error <= threshold * validation_metric
142-
return tracer[i][:solution], i
143-
end
144-
end
135+
Rv_Vt = Rv * Vt
145136

146-
elseif criterion == :bysize
147-
first_index = findfirst(sublist -> threshold in sublist, nnzss)
148-
relevant_errors = errors[1:first_index - 1]
149-
min_error = minimum(relevant_errors)
150-
min_error_index = findfirst(==(min_error), relevant_errors)
151-
return tracer[min_error_index][:solution], min_error_index
137+
θv = zeros(size(Av, 2))
138+
θv[1] = zt[1] / Σt[1]
139+
rv = Rv_Vt[:, 1] * θv[1] - zv
152140

153-
else
154-
@error("Unknown selection criterion: $criterion")
155-
end
141+
tsvd_errs = Float64[]
142+
push!(tsvd_errs, norm(rv))
143+
144+
for k = 2:length(Σt)
145+
θv[k] = zt[k] / Σt[k]
146+
rv += Rv_Vt[:, k] * θv[k]
147+
push!(tsvd_errs, norm(rv))
148+
end
149+
150+
imin = argmin(tsvd_errs)
151+
θv[imin+1:end] .= 0
152+
return Vt * θv, Σt[imin]
153+
end
154+
155+
function post_asp_tsvd(path, At, yt, Av, yv)
156+
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
157+
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
158+
159+
function _post(θλ)
160+
(θ, λ) = θλ
161+
if isempty(θ.nzind); return (θ = θ, λ = λ, σ = Inf); end
162+
inz = θ.nzind
163+
θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
164+
θ2 = copy(θ); θ2[inz] .= θ1
165+
return (θ = θ2, λ = λ, σ = σ)
166+
end
167+
168+
return _post.(path)
169+
170+
# post = []
171+
# for (θ, λ) in path
172+
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
173+
# inz = θ.nzind
174+
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
175+
# θ2 = copy(θ); θ2[inz] .= θ1
176+
# push!(post, (θ = θ2, λ = λ, σ = σ))
177+
# end
178+
# return identity.(post)
156179
end
157-
=#

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ using Test
88

99
@testset "Linear Solvers" begin include("test_linearsolvers.jl") end
1010

11+
@testset "ASP" begin include("test_asp.jl") end
12+
1113
@testset "MLJ Solvers" begin include("test_mlj.jl") end
1214
end

test/test_asp.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
using ACEfit
2+
using LinearAlgebra, Random, Test
3+
using Random
4+
5+
##
6+
7+
@info("Test Solver on overdetermined system")
8+
9+
Random.seed!(1234)
10+
Nobs = 10_000
11+
Nfeat = 100
12+
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
13+
U, S1, V = svd(A1)
14+
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
15+
A = U * Diagonal(S) * V'
16+
c_ref = randn(Nfeat)
17+
epsn = 1e-5
18+
y = A * c_ref + epsn * randn(Nobs) / sqrt(Nobs)
19+
P = Diagonal(1.0 .+ rand(Nfeat))
20+
21+
##
22+
23+
@info(" ... ASP")
24+
shuffled_indices = shuffle(1:length(y))
25+
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
26+
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
27+
At = A[train_indices,:]
28+
Av = A[val_indices,:]
29+
yt = y[train_indices]
30+
yv = y[val_indices]
31+
32+
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
33+
( (:byerror,1.3), 10*epsn, 1),
34+
( (:bysize,73), 1, 10) ]
35+
@show select
36+
local solver, results, C
37+
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
38+
# without validation
39+
results = ACEfit.solve(solver, A, y)
40+
C = results["C"]
41+
full_path = results["path"]
42+
@show results["nnzs"]
43+
@show norm(A * C - y)
44+
@show norm(C)
45+
@show norm(C - c_ref)
46+
47+
@test norm(A * C - y) < tolr
48+
@test norm(C - c_ref) < tolc
49+
50+
51+
# with validation
52+
results = ACEfit.solve(solver, At, yt, Av, yv)
53+
C = results["C"]
54+
full_path = results["path"]
55+
@show results["nnzs"]
56+
@show norm(Av * C - yv)
57+
@show norm(C)
58+
@show norm(C - c_ref)
59+
60+
@test norm(Av * C - yv) < tolr
61+
@test norm(C - c_ref) < tolc
62+
end
63+
64+
##
65+
66+
67+
# I didn't wanna add more tsvd tests to yours so I just wrote this one
68+
# I only wanted to naïvely demonstrate that tsvd actually does make a difference! :)
69+
70+
for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
71+
( (:byerror,1.3), 20*epsn, 1.5),
72+
( (:bysize,73), 1, 10) ]
73+
@show select
74+
local solver, results, C
75+
solver_tsvd = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=true,
76+
nstore=100, loglevel=0, traceFlag=true)
77+
78+
solver = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=false,
79+
nstore=100, loglevel=0, traceFlag=true)
80+
# without validation
81+
results_tsvd = ACEfit.solve(solver_tsvd, A, y)
82+
results = ACEfit.solve(solver, A, y)
83+
C_tsvd = results_tsvd["C"]
84+
C = results["C"]
85+
86+
@show results["nnzs"]
87+
@show norm(A * C - y)
88+
@show norm(A * C_tsvd - y)
89+
if norm(A * C_tsvd - y)< norm(A * C - y)
90+
@info "tsvd made improvements!"
91+
else
92+
@warn "tsvd did NOT make any improvements!"
93+
end
94+
95+
96+
# with validation
97+
results_tsvd = ACEfit.solve(solver_tsvd, At, yt, Av, yv)
98+
results = ACEfit.solve(solver, At, yt, Av, yv)
99+
C_tsvd = results_tsvd["C"]
100+
C = results["C"]
101+
@show results["nnzs"]
102+
@show norm(A * C - y)
103+
@show norm(A * C_tsvd - y)
104+
105+
if norm(A * C_tsvd - y)< norm(A * C - y)
106+
@info "tsvd made improvements!"
107+
else
108+
@warn "tsvd did NOT make any improvements!"
109+
end
110+
end
111+

test/test_linearsolvers.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -168,45 +168,3 @@ C = results["C"]
168168
@test norm(A * C - y) < 10 * epsn
169169
@test norm(C - c_ref) < 1
170170

171-
##
172-
173-
@info(" ... ASP")
174-
shuffled_indices = shuffle(1:length(y))
175-
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
176-
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
177-
At = A[train_indices,:]
178-
Av = A[val_indices,:]
179-
yt = y[train_indices]
180-
yv = y[val_indices]
181-
182-
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
183-
( (:byerror,1.3), 10*epsn, 1),
184-
( (:bysize,73), 1, 10) ]
185-
@show select
186-
local solver, results, C
187-
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
188-
# without validation
189-
results = ACEfit.solve(solver, A, y)
190-
C = results["C"]
191-
full_path = results["path"]
192-
@show results["nnzs"]
193-
@show norm(A * C - y)
194-
@show norm(C)
195-
@show norm(C - c_ref)
196-
197-
@test norm(A * C - y) < tolr
198-
@test norm(C - c_ref) < tolc
199-
200-
201-
# with validation
202-
results = ACEfit.solve(solver, At, yt, Av, yv)
203-
C = results["C"]
204-
full_path = results["path"]
205-
@show results["nnzs"]
206-
@show norm(Av * C - yv)
207-
@show norm(C)
208-
@show norm(C - c_ref)
209-
210-
@test norm(Av * C - yv) < tolr
211-
@test norm(C - c_ref) < tolc
212-
end

0 commit comments

Comments
 (0)