|
| 1 | +using ACEfit |
| 2 | +using LinearAlgebra, Random, Test |
| 3 | +using Random |
| 4 | + |
| 5 | +## |
| 6 | + |
| 7 | +@info("Test Solver on overdetermined system") |
| 8 | + |
| 9 | +Random.seed!(1234) |
| 10 | +Nobs = 10_000 |
| 11 | +Nfeat = 300 |
| 12 | +A1 = randn(Nobs, Nfeat) / sqrt(Nobs) |
| 13 | +U, S1, V = svd(A1) |
| 14 | +S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2 |
| 15 | +A = U * Diagonal(S) * V' |
| 16 | +c_ref = randn(Nfeat) |
| 17 | +epsn = 1e-5 |
| 18 | +y = A * c_ref + epsn * randn(Nobs) / sqrt(Nobs) |
| 19 | +P = Diagonal(1.0 .+ rand(Nfeat)) |
| 20 | + |
| 21 | +## |
| 22 | + |
| 23 | +@info(" ... ASP") |
| 24 | +shuffled_indices = shuffle(1:length(y)) |
| 25 | +train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))] |
| 26 | +val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end] |
| 27 | +At = A[train_indices,:] |
| 28 | +Av = A[val_indices,:] |
| 29 | +yt = y[train_indices] |
| 30 | +yv = y[val_indices] |
| 31 | + |
| 32 | +for (select, tolr, tolc) in [ (:final, 10*epsn, 1), |
| 33 | + ( (:byerror,1.3), 10*epsn, 1), |
| 34 | + ( (:bysize,73), 1, 10) ] |
| 35 | + @show select |
| 36 | + local solver, results, C |
| 37 | + solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true) |
| 38 | + # without validation |
| 39 | + results = ACEfit.solve(solver, A, y) |
| 40 | + C = results["C"] |
| 41 | + full_path = results["path"] |
| 42 | + @show results["nnzs"] |
| 43 | + @show norm(A * C - y) |
| 44 | + @show norm(C) |
| 45 | + @show norm(C - c_ref) |
| 46 | + |
| 47 | + @test norm(A * C - y) < tolr |
| 48 | + @test norm(C - c_ref) < tolc |
| 49 | + |
| 50 | + |
| 51 | + # with validation |
| 52 | + results = ACEfit.solve(solver, At, yt, Av, yv) |
| 53 | + C = results["C"] |
| 54 | + full_path = results["path"] |
| 55 | + @show results["nnzs"] |
| 56 | + @show norm(Av * C - yv) |
| 57 | + @show norm(C) |
| 58 | + @show norm(C - c_ref) |
| 59 | + |
| 60 | + @test norm(Av * C - yv) < tolr |
| 61 | + @test norm(C - c_ref) < tolc |
| 62 | +end |
| 63 | + |
| 64 | +## |
| 65 | + |
| 66 | +# Experimental Implementation of tsvd postprocessing |
| 67 | + |
| 68 | + |
| 69 | +using SparseArrays |
| 70 | + |
| 71 | +function solve_tsvd(At, yt, Av, yv) |
| 72 | + Ut, Σt, Vt = svd(At); zt = Ut' * yt |
| 73 | + Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv |
| 74 | + @assert issorted(Σt, rev=true) |
| 75 | + |
| 76 | + Rv_Vt = Rv * Vt |
| 77 | + |
| 78 | + θv = zeros(size(Av, 2)) |
| 79 | + θv[1] = zt[1] / Σt[1] |
| 80 | + rv = Rv_Vt[:, 1] * θv[1] - zv |
| 81 | + |
| 82 | + tsvd_errs = Float64[] |
| 83 | + push!(tsvd_errs, norm(rv)) |
| 84 | + |
| 85 | + for k = 2:length(Σt) |
| 86 | + θv[k] = zt[k] / Σt[k] |
| 87 | + rv += Rv_Vt[:, k] * θv[k] |
| 88 | + push!(tsvd_errs, norm(rv)) |
| 89 | + end |
| 90 | + |
| 91 | + imin = argmin(tsvd_errs) |
| 92 | + θv[imin+1:end] .= 0 |
| 93 | + return Vt * θv, Σt[imin] |
| 94 | +end |
| 95 | + |
| 96 | +function post_asp_tsvd(path, At, yt, Av, yv) |
| 97 | + Qt, Rt = qr(At); zt = Matrix(Qt)' * yt |
| 98 | + Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv |
| 99 | + |
| 100 | + post = [] |
| 101 | + for (θ, λ) in path |
| 102 | + if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end |
| 103 | + inz = θ.nzind |
| 104 | + θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv) |
| 105 | + θ2 = copy(θ); θ2[inz] .= θ1 |
| 106 | + push!(post, (θ = θ2, λ = λ, σ = σ)) |
| 107 | + end |
| 108 | + return identity.(post) |
| 109 | +end |
| 110 | + |
| 111 | +solver = ACEfit.ASP(P=I, select = :final, loglevel=0, traceFlag=true) |
| 112 | +result = ACEfit.solve(solver, At, yt); |
| 113 | +post = post_asp_tsvd(result["path"], At, yt, Av, yv); |
0 commit comments