Skip to content

Commit 83dbee5

Browse files
author
Christoph Ortner
committed
cleanup, activate asp tests
1 parent 59723c5 commit 83dbee5

File tree

3 files changed

+27
-66
lines changed

3 files changed

+27
-66
lines changed

src/asp.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,24 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
6969

7070
tracer = asp_homotopy(AP, y; solver.params...)
7171

72-
q = length(tracer)
73-
every = max(1, q ÷ solver.nstore)
74-
new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, q)
75-
new_tracer = [(solution = solver.P \ tracer[i][1], λ = tracer[i][2]) for i in [1:every:q; q]]
72+
q = length(tracer)
73+
every = max(1, q ÷ solver.nstore)
74+
istore = unique([1:every:q; q])
75+
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
76+
for i in istore ]
7677

7778
if solver.tsvd # Post-processing if tsvd is true
7879
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
79-
new_post = [(solution = p.θ, λ = p.λ) for p in post]
80+
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
8081
else
8182
new_post = new_tracer
8283
end
8384

8485
xs, in = select_solution(new_post, solver, Aval, yval)
8586

86-
# println("done.")
8787
return Dict( "C" => xs,
8888
"path" => new_post,
89-
"nnzs" => length((new_tracer[in][:solution]).nzind) )
89+
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
9090
end
9191

9292

@@ -156,13 +156,24 @@ function post_asp_tsvd(path, At, yt, Av, yv)
156156
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
157157
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
158158

159-
post = []
160-
for (θ, λ) in path
161-
if isempty.nzind); push!(post,= θ, λ = λ, σ = Inf)); continue; end
159+
function _post(θλ)
160+
(θ, λ) = θλ
161+
if isempty.nzind); return= θ, λ = λ, σ = Inf); end
162162
inz = θ.nzind
163163
θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
164164
θ2 = copy(θ); θ2[inz] .= θ1
165-
push!(post, (θ = θ2, λ = λ, σ = σ))
166-
end
167-
return identity.(post)
165+
return= θ2, λ = λ, σ = σ)
166+
end
167+
168+
return _post.(path)
169+
170+
# post = []
171+
# for (θ, λ) in path
172+
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
173+
# inz = θ.nzind
174+
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
175+
# θ2 = copy(θ); θ2[inz] .= θ1
176+
# push!(post, (θ = θ2, λ = λ, σ = σ))
177+
# end
178+
# return identity.(post)
168179
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ using Test
88

99
@testset "Linear Solvers" begin include("test_linearsolvers.jl") end
1010

11+
@testset "ASP" begin include("test_asp.jl") end
12+
1113
@testset "MLJ Solvers" begin include("test_mlj.jl") end
1214
end

test/test_asp.jl

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Random
88

99
Random.seed!(1234)
1010
Nobs = 10_000
11-
Nfeat = 300
11+
Nfeat = 100
1212
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
1313
U, S1, V = svd(A1)
1414
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
@@ -109,55 +109,3 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
109109
end
110110
end
111111

112-
113-
##
114-
115-
# Experimental Implementation of tsvd postprocessing
116-
117-
118-
# using SparseArrays
119-
120-
# function solve_tsvd(At, yt, Av, yv)
121-
# Ut, Σt, Vt = svd(At); zt = Ut' * yt
122-
# Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
123-
# @assert issorted(Σt, rev=true)
124-
125-
# Rv_Vt = Rv * Vt
126-
127-
# θv = zeros(size(Av, 2))
128-
# θv[1] = zt[1] / Σt[1]
129-
# rv = Rv_Vt[:, 1] * θv[1] - zv
130-
131-
# tsvd_errs = Float64[]
132-
# push!(tsvd_errs, norm(rv))
133-
134-
# for k = 2:length(Σt)
135-
# θv[k] = zt[k] / Σt[k]
136-
# rv += Rv_Vt[:, k] * θv[k]
137-
# push!(tsvd_errs, norm(rv))
138-
# end
139-
140-
# imin = argmin(tsvd_errs)
141-
# θv[imin+1:end] .= 0
142-
# return Vt * θv, Σt[imin]
143-
# end
144-
145-
# function post_asp_tsvd(path, At, yt, Av, yv)
146-
# Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
147-
# Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
148-
149-
# post = []
150-
# for (θ, λ) in path
151-
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
152-
# inz = θ.nzind
153-
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
154-
# θ2 = copy(θ); θ2[inz] .= θ1
155-
# push!(post, (θ = θ2, λ = λ, σ = σ))
156-
# end
157-
# return identity.(post)
158-
# end
159-
160-
# solver = ACEfit.ASP(P=I, select = :final, loglevel=0, traceFlag=true)
161-
# result = ACEfit.solve(solver, At, yt);
162-
# post = post_asp_tsvd(result["path"], At, yt, Av, yv);
163-

0 commit comments

Comments
 (0)