@@ -48,35 +48,45 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
4848If independent `Aval` and `yval` are provided (instead of detaults `A, y`),
4949then the solver will use this separate validation set instead of the training
5050set to select the best solution along the model path.
51- """
51+ # """
52+
5253struct ASP
5354 P
5455 select
56+ mode:: Symbol
57+ tsvd:: Bool
58+ nstore:: Integer
5559 params
5660end
5761
58- function ASP (; P = I, select, mode= :train , params... )
59- return ASP (P, select, params)
62+ function ASP (; P = I, select, mode= :train , tsvd = false , nstore = 100 , params... )
63+ return ASP (P, select, mode, tsvd, nstore, params)
6064end
6165
6266function solve (solver:: ASP , A, y, Aval= A, yval= y)
6367 # Apply preconditioning
6468 AP = A / solver. P
6569
6670 tracer = asp_homotopy (AP, y; solver. params... )
67- q = length (tracer)
68- new_tracer = Vector {NamedTuple{(:solution, :λ), Tuple{Any, Any}}} (undef, q)
6971
70- for i in 1 : q
71- new_tracer[i] = (solution = solver. P \ tracer[i][1 ], λ = tracer[i][2 ])
72+ q = length (tracer)
73+ every = max (1 , q ÷ solver. nstore)
74+ istore = unique ([1 : every: q; q])
75+ new_tracer = [ (solution = solver. P \ tracer[i][1 ], λ = tracer[i][2 ], σ = 0.0 )
76+ for i in istore ]
77+
78+ if solver. tsvd # Post-processing if tsvd is true
79+ post = post_asp_tsvd (new_tracer, A, y, Aval, yval)
80+ new_post = [ (solution = p. θ, λ = p. λ, σ = p. σ) for p in post ]
81+ else
82+ new_post = new_tracer
7283 end
7384
74- xs, in = select_solution (new_tracer , solver, Aval, yval)
85+ xs, in = select_solution (new_post , solver, Aval, yval)
7586
76- # println("done.")
77- return Dict ( " C" => xs,
78- " path" => new_tracer,
79- " nnzs" => length ((new_tracer[in][:solution ]). nzind) )
87+ return Dict ( " C" => xs,
88+ " path" => new_post,
89+ " nnzs" => length ( (new_tracer[in][:solution ]). nzind) )
8090end
8191
8292
@@ -114,44 +124,56 @@ function select_solution(tracer, solver, A, y)
114124end
115125
116126
117- #=
118- function select_smart(tracer, solver, Aval, yval)
119127
120- best_metric = Inf
121- best_iteration = 0
122- validation_metric = 0
123- q = length(tracer)
124- errors = [norm(Aval * t[:solution] - yval) for t in tracer]
125- nnzss = [(t[:solution]).nzind for t in tracer]
126- best_iteration = argmin(errors)
127- validation_metric = errors[best_iteration]
128- validation_end = norm(Aval * tracer[end][:solution] - yval)
128+ using SparseArrays
129129
130- if validation_end < validation_metric #make sure to check the last one too in case q<<100
131- best_iteration = q
132- end
130+ function solve_tsvd (At, yt, Av, yv)
131+ Ut, Σt, Vt = svd (At); zt = Ut' * yt
132+ Qv, Rv = qr (Av); zv = Matrix (Qv)' * yv
133+ @assert issorted (Σt, rev= true )
133134
134- criterion, threshold = solver.select
135-
136- if criterion == :val
137- return tracer[best_iteration][:solution], best_iteration
138-
139- elseif criterion == :byerror
140- for (i, error) in enumerate(errors)
141- if error <= threshold * validation_metric
142- return tracer[i][:solution], i
143- end
144- end
135+ Rv_Vt = Rv * Vt
145136
146- elseif criterion == :bysize
147- first_index = findfirst(sublist -> threshold in sublist, nnzss)
148- relevant_errors = errors[1:first_index - 1]
149- min_error = minimum(relevant_errors)
150- min_error_index = findfirst(==(min_error), relevant_errors)
151- return tracer[min_error_index][:solution], min_error_index
137+ θv = zeros (size (Av, 2 ))
138+ θv[1 ] = zt[1 ] / Σt[1 ]
139+ rv = Rv_Vt[:, 1 ] * θv[1 ] - zv
152140
153- else
154- @error("Unknown selection criterion: $criterion")
155- end
141+ tsvd_errs = Float64[]
142+ push! (tsvd_errs, norm (rv))
143+
144+ for k = 2 : length (Σt)
145+ θv[k] = zt[k] / Σt[k]
146+ rv += Rv_Vt[:, k] * θv[k]
147+ push! (tsvd_errs, norm (rv))
148+ end
149+
150+ imin = argmin (tsvd_errs)
151+ θv[imin+ 1 : end ] .= 0
152+ return Vt * θv, Σt[imin]
153+ end
154+
155+ function post_asp_tsvd (path, At, yt, Av, yv)
156+ Qt, Rt = qr (At); zt = Matrix (Qt)' * yt
157+ Qv, Rv = qr (Av); zv = Matrix (Qv)' * yv
158+
159+ function _post (θλ)
160+ (θ, λ) = θλ
161+ if isempty (θ. nzind); return (θ = θ, λ = λ, σ = Inf ); end
162+ inz = θ. nzind
163+ θ1, σ = solve_tsvd (Rt[:, inz], zt, Rv[:, inz], zv)
164+ θ2 = copy (θ); θ2[inz] .= θ1
165+ return (θ = θ2, λ = λ, σ = σ)
166+ end
167+
168+ return _post .(path)
169+
170+ # post = []
171+ # for (θ, λ) in path
172+ # if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
173+ # inz = θ.nzind
174+ # θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
175+ # θ2 = copy(θ); θ2[inz] .= θ1
176+ # push!(post, (θ = θ2, λ = λ, σ = σ))
177+ # end
178+ # return identity.(post)
156179end
157- =#
0 commit comments