Skip to content

Commit 6b84ab7

Browse files
authored
added modes to ASP
1 parent daff721 commit 6b84ab7

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

test/test_linearsolvers.jl

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
using ACEfit
33
using LinearAlgebra
4+
using Random
45
using PythonCall
56

67
@info("Test Solver on overdetermined system")
@@ -111,32 +112,34 @@ C = results["C"]
111112
@show norm(C)
112113
@show norm(C - c_ref)
113114

114-
@info(" ... ASP_homotopy selected by error")
115-
solver = ACEfit.ASP(P = P, select = (:byerror,1.5), params = (loglevel=0, traceFlag=true))
116-
results = ACEfit.solve(solver, A, y)
117-
C = results["C"]
118-
full_path = results["path"]
119-
@show results["nnzs"]
120-
@show norm(A * C - y)
121-
@show norm(C)
122-
@show norm(C - c_ref)
123-
124-
@info(" ... ASP_homotopy selected by size")
125-
solver = ACEfit.ASP(P = P, select = (:bysize,50), params = (loglevel=0, traceFlag=true))
126-
results = ACEfit.solve(solver, A, y)
127-
C = results["C"]
128-
full_path = results["path"]
129-
@show results["nnzs"]
130-
@show norm(A * C - y)
131-
@show norm(C)
132-
@show norm(C - c_ref)
133115

134-
@info(" ... ASP_homotopy final solution")
135-
solver = ACEfit.ASP(P = P, select = (:final,nothing), params = (loglevel=0, traceFlag=true))
136-
results = ACEfit.solve(solver, A, y)
137-
C = results["C"]
138-
full_path = results["path"]
139-
@show results["nnzs"]
140-
@show norm(A * C - y)
141-
@show norm(C)
142-
@show norm(C - c_ref)
116+
@info(" ... ASP")
117+
shuffled_indices = shuffle(1:length(y))
118+
train_indices = shuffled_indices[1:round(Int, 0.85 * length(y))]
119+
val_indices = shuffled_indices[round(Int, 0.85 * length(y)) + 1:end]
120+
At = A[train_indices,:]
121+
Av = A[val_indices,:]
122+
yt = y[train_indices]
123+
yv = y[val_indices]
124+
125+
for (sel, mod) in [((:final,nothing),:basic ),( (:byerror,1.3),:basic ),((:bysize,73),:basic )
126+
,((:val,nothing),:smart ),((:byerror,1.3),:smart ),((:bysize,73),:smart )]
127+
solver = ACEfit.ASP(P=I, select= sel, mode=mod ,params = (loglevel=0, traceFlag=true))
128+
if mod == :basic
129+
results = ACEfit.solve(solver, A, y)
130+
C = results["C"]
131+
full_path = results["path"]
132+
@show results["nnzs"]
133+
@show norm(A * C - y)
134+
@show norm(C)
135+
@show norm(C - c_ref)
136+
elseif mod == :smart
137+
results = ACEfit.solve(solver, At, yt, Av, yv)
138+
C = results["C"]
139+
full_path = results["path"]
140+
@show results["nnzs"]
141+
@show norm(Av * C - yv)
142+
@show norm(C)
143+
@show norm(C - c_ref)
144+
end
145+
end

0 commit comments

Comments
 (0)