Skip to content

Commit a64d8d8

Browse files
author
Christoph Ortner
committed
experimental tsvd postprocessing
1 parent e0fea2e commit a64d8d8

File tree

2 files changed

+113
-42
lines changed

2 files changed

+113
-42
lines changed

test/test_asp.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
using ACEfit
2+
using LinearAlgebra, Random, Test
3+
using Random
4+
5+
##
6+
7+
@info("Test Solver on overdetermined system")
8+
9+
Random.seed!(1234)
10+
Nobs = 10_000
11+
Nfeat = 300
12+
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
13+
U, S1, V = svd(A1)
14+
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
15+
A = U * Diagonal(S) * V'
16+
c_ref = randn(Nfeat)
17+
epsn = 1e-5
18+
y = A * c_ref + epsn * randn(Nobs) / sqrt(Nobs)
19+
P = Diagonal(1.0 .+ rand(Nfeat))
20+
21+
##
22+
23+
@info(" ... ASP")
24+
shuffled_indices = shuffle(1:length(y))
25+
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
26+
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
27+
At = A[train_indices,:]
28+
Av = A[val_indices,:]
29+
yt = y[train_indices]
30+
yv = y[val_indices]
31+
32+
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
33+
( (:byerror,1.3), 10*epsn, 1),
34+
( (:bysize,73), 1, 10) ]
35+
@show select
36+
local solver, results, C
37+
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
38+
# without validation
39+
results = ACEfit.solve(solver, A, y)
40+
C = results["C"]
41+
full_path = results["path"]
42+
@show results["nnzs"]
43+
@show norm(A * C - y)
44+
@show norm(C)
45+
@show norm(C - c_ref)
46+
47+
@test norm(A * C - y) < tolr
48+
@test norm(C - c_ref) < tolc
49+
50+
51+
# with validation
52+
results = ACEfit.solve(solver, At, yt, Av, yv)
53+
C = results["C"]
54+
full_path = results["path"]
55+
@show results["nnzs"]
56+
@show norm(Av * C - yv)
57+
@show norm(C)
58+
@show norm(C - c_ref)
59+
60+
@test norm(Av * C - yv) < tolr
61+
@test norm(C - c_ref) < tolc
62+
end
63+
64+
##
65+
66+
# Experimental Implementation of tsvd postprocessing
67+
68+
69+
using SparseArrays
70+
71+
function solve_tsvd(At, yt, Av, yv)
72+
Ut, Σt, Vt = svd(At); zt = Ut' * yt
73+
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
74+
@assert issorted(Σt, rev=true)
75+
76+
Rv_Vt = Rv * Vt
77+
78+
θv = zeros(size(Av, 2))
79+
θv[1] = zt[1] / Σt[1]
80+
rv = Rv_Vt[:, 1] * θv[1] - zv
81+
82+
tsvd_errs = Float64[]
83+
push!(tsvd_errs, norm(rv))
84+
85+
for k = 2:length(Σt)
86+
θv[k] = zt[k] / Σt[k]
87+
rv += Rv_Vt[:, k] * θv[k]
88+
push!(tsvd_errs, norm(rv))
89+
end
90+
91+
imin = argmin(tsvd_errs)
92+
θv[imin+1:end] .= 0
93+
return Vt * θv, Σt[imin]
94+
end
95+
96+
function post_asp_tsvd(path, At, yt, Av, yv)
97+
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
98+
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
99+
100+
post = []
101+
for (θ, λ) in path
102+
if isempty.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
103+
inz = θ.nzind
104+
θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
105+
θ2 = copy(θ); θ2[inz] .= θ1
106+
push!(post, (θ = θ2, λ = λ, σ = σ))
107+
end
108+
return identity.(post)
109+
end
110+
111+
solver = ACEfit.ASP(P=I, select = :final, loglevel=0, traceFlag=true)
112+
result = ACEfit.solve(solver, At, yt);
113+
post = post_asp_tsvd(result["path"], At, yt, Av, yv);

test/test_linearsolvers.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -168,45 +168,3 @@ C = results["C"]
168168
@test norm(A * C - y) < 10 * epsn
169169
@test norm(C - c_ref) < 1
170170

171-
##
172-
173-
@info(" ... ASP")
174-
shuffled_indices = shuffle(1:length(y))
175-
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
176-
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
177-
At = A[train_indices,:]
178-
Av = A[val_indices,:]
179-
yt = y[train_indices]
180-
yv = y[val_indices]
181-
182-
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
183-
( (:byerror,1.3), 10*epsn, 1),
184-
( (:bysize,73), 1, 10) ]
185-
@show select
186-
local solver, results, C
187-
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
188-
# without validation
189-
results = ACEfit.solve(solver, A, y)
190-
C = results["C"]
191-
full_path = results["path"]
192-
@show results["nnzs"]
193-
@show norm(A * C - y)
194-
@show norm(C)
195-
@show norm(C - c_ref)
196-
197-
@test norm(A * C - y) < tolr
198-
@test norm(C - c_ref) < tolc
199-
200-
201-
# with validation
202-
results = ACEfit.solve(solver, At, yt, Av, yv)
203-
C = results["C"]
204-
full_path = results["path"]
205-
@show results["nnzs"]
206-
@show norm(Av * C - yv)
207-
@show norm(C)
208-
@show norm(C - c_ref)
209-
210-
@test norm(Av * C - yv) < tolr
211-
@test norm(C - c_ref) < tolc
212-
end

0 commit comments

Comments
 (0)