Skip to content

Commit f892f48

Browse files
author
Christoph Ortner
committed
add svd with validation set
1 parent 1b58c75 commit f892f48

File tree

3 files changed

+58
-46
lines changed

3 files changed

+58
-46
lines changed

src/asp.jl

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,23 @@ end
6666
function solve(solver::ASP, A, y, Aval=A, yval=y)
6767
# Apply preconditioning
6868
AP = A / solver.P
69+
AvalP = Aval / solver.P
6970

7071
tracer = asp_homotopy(AP, y; solver.params...)
7172

7273
q = length(tracer)
7374
every = max(1, q ÷ solver.nstore)
7475
istore = unique([1:every:q; q])
75-
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
76+
new_tracer = [ (solution = tracer[i][1], λ = tracer[i][2], σ = 0.0 )
7677
for i in istore ]
7778

7879
if solver.tsvd # Post-processing if tsvd is true
79-
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
80-
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
80+
post = post_asp_tsvd(new_tracer, AP, y, AvalP, yval)
81+
new_post = [ (solution = solver.P \ p.θ, λ = p.λ, σ = p.σ)
82+
for p in post ]
8183
else
82-
new_post = new_tracer
84+
new_post = [ (solution = solver.P \ p.solution, λ = p.λ, σ = 0.0)
85+
for p in new_tracer ]
8386
end
8487

8588
xs, in = select_solution(new_post, solver, Aval, yval)
@@ -124,34 +127,6 @@ function select_solution(tracer, solver, A, y)
124127
end
125128

126129

127-
128-
using SparseArrays
129-
130-
function solve_tsvd(At, yt, Av, yv)
131-
Ut, Σt, Vt = svd(At); zt = Ut' * yt
132-
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
133-
@assert issorted(Σt, rev=true)
134-
135-
Rv_Vt = Rv * Vt
136-
137-
θv = zeros(size(Av, 2))
138-
θv[1] = zt[1] / Σt[1]
139-
rv = Rv_Vt[:, 1] * θv[1] - zv
140-
141-
tsvd_errs = Float64[]
142-
push!(tsvd_errs, norm(rv))
143-
144-
for k = 2:length(Σt)
145-
θv[k] = zt[k] / Σt[k]
146-
rv += Rv_Vt[:, k] * θv[k]
147-
push!(tsvd_errs, norm(rv))
148-
end
149-
150-
imin = argmin(tsvd_errs)
151-
θv[imin+1:end] .= 0
152-
return Vt * θv, Σt[imin]
153-
end
154-
155130
function post_asp_tsvd(path, At, yt, Av, yv)
156131
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
157132
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
@@ -166,14 +141,4 @@ function post_asp_tsvd(path, At, yt, Av, yv)
166141
end
167142

168143
return _post.(path)
169-
170-
# post = []
171-
# for (θ, λ) in path
172-
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
173-
# inz = θ.nzind
174-
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
175-
# θ2 = copy(θ); θ2[inz] .= θ1
176-
# push!(post, (θ = θ2, λ = λ, σ = σ))
177-
# end
178-
# return identity.(post)
179144
end

src/solvers.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,37 @@ function solve(solver::TruncatedSVD, A, y)
196196
return Dict{String, Any}("C" => solver.P \ θP)
197197
end
198198

199+
200+
# ------------ Truncated SVD with tol specified by validation set ------------
201+
202+
function solve_tsvd(At, yt, Av, yv)
203+
Ut, Σt, Vt = svd(At); zt = Ut' * yt
204+
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
205+
@assert issorted(Σt, rev=true)
206+
207+
Rv_Vt = Rv * Vt
208+
209+
θv = zeros(size(Av, 2))
210+
θv[1] = zt[1] / Σt[1]
211+
rv = Rv_Vt[:, 1] * θv[1] - zv
212+
213+
tsvd_errs = Float64[]
214+
push!(tsvd_errs, norm(rv))
215+
216+
for k = 2:length(Σt)
217+
θv[k] = zt[k] / Σt[k]
218+
rv += Rv_Vt[:, k] * θv[k]
219+
push!(tsvd_errs, norm(rv))
220+
end
221+
222+
imin = argmin(tsvd_errs)
223+
θv[imin+1:end] .= 0
224+
return Vt * θv, Σt[imin]
225+
end
226+
227+
228+
function solve(solver::TruncatedSVD, At, yt, Av, yv)
229+
# make a function barrier because solver.P is not inferred
230+
θ, σ = solve_tsvd(At / solver.P, yt, Av / solver.P, yv)
231+
return Dict{String, Any}("C" => solver.P \ θ, "σ" => σ)
232+
end

test/test_linearsolvers.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11

2-
using ACEfit
3-
using LinearAlgebra, Random, Test
4-
using Random
5-
using PythonCall
2+
using ACEfit, LinearAlgebra, Random, Test, PythonCall
63

74
##
85

@@ -168,3 +165,19 @@ C = results["C"]
168165
@test norm(A * C - y) < 10 * epsn
169166
@test norm(C - c_ref) < 1
170167

168+
169+
##
170+
171+
@info("Truncated SVD with validation")
172+
solver = ACEfit.TruncatedSVD(; rtol = 0.0)
173+
At = A[1:8000, :]
174+
yt = y[1:8000]
175+
Av = A[8001:end, :]
176+
yv = y[8001:end]
177+
results_v = ACEfit.solve(solver, At, yt, Av, yv)
178+
@show err_v = norm(Av * results_v["C"] - yv)
179+
@show err = norm(Av * results["C"] - yv)
180+
@test err_v <= err
181+
@show norm(results_v["C"] - c_ref)
182+
@show norm(results["C"] - c_ref)
183+
@test norm(results_v["C"] - c_ref) < 1e-2

0 commit comments

Comments
 (0)