|
2 | 2 | using Distributed, Random, SparseArrays |
3 | 3 | addprocs(10, exeflags="--project=$(Base.active_project())") |
4 | 4 | @everywhere using ACEpotentials, PrettyTables |
| 5 | +using ACEpotentials.Models: fast_evaluator |
5 | 6 |
|
6 | 7 | ## |
7 | 8 |
|
@@ -33,32 +34,30 @@ At, yt, Wt = ACEpotentials.assemble(train_data, model) |
33 | 34 | Av, yv, Wv = ACEpotentials.assemble(val_data, model) |
34 | 35 |
|
35 | 36 | @info("Compute ASP Path") |
36 | | -solver = ACEfit.ASP(; P = P, select = :final, tsvd = true, |
37 | | - actMax = 1000, traceFlag=true ) |
| 37 | +solver = ACEfit.ASP(; P = P, select = :final, tsvd = true, actMax = 1000 ) |
38 | 38 | asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv) |
39 | 39 |
|
40 | 40 | ## |
41 | 41 |
|
42 | 42 | @info("Pick solutions for 100, 300, 1000 parameters, compute errors") |
43 | 43 |
|
44 | | -@show length(asp_result["path"]) |
45 | | -path = asp_result["path"] |
46 | | -nnzs = [ nnz(p.solution) for p in path ] |
47 | | -I1000 = length(nnzs) |
48 | | -I300 = findfirst(nnzs .>= 300) |
49 | | -I100 = findfirst(nnzs .>= 100) |
50 | | - |
51 | | -model_1000 = deepcopy(model) |
52 | | -set_parameters!(model_1000, path[I1000].solution) |
53 | | -model_300 = deepcopy(model) |
54 | | -set_parameters!(model_300, path[I300].solution) |
55 | | -model_100 = deepcopy(model) |
56 | | -set_parameters!(model_100, path[I100].solution) |
57 | | - |
58 | | -err_100 = ACEpotentials.linear_errors(test_data, model_100) |
59 | | -err_300 = ACEpotentials.linear_errors(test_data, model_300) |
60 | | -err_1000 = ACEpotentials.linear_errors(test_data, model_1000) |
61 | | - |
| 44 | +# select models from the model path |
| 45 | +model_1000 = set_parameters!( deepcopy(model), |
| 46 | + ACEfit.asp_select(asp_result, :final)[1]) |
| 47 | +model_300 = set_parameters!( deepcopy(model), |
| 48 | + ACEfit.asp_select(asp_result, (:bysize, 300))[1]) |
| 49 | +model_100 = set_parameters!( deepcopy(model), |
| 50 | + ACEfit.asp_select(asp_result, (:bysize, 100))[1]) |
| 51 | + |
| 52 | +# generate sparsified, faster evaluators |
| 53 | +pot_1000 = fast_evaluator(model_1000; aa_static = false) # static can cause stack overflow |
| 54 | +pot_300 = fast_evaluator(model_300; aa_static = true) |
| 55 | +pot_100 = fast_evaluator(model_100; aa_static = true) |
| 56 | + |
| 57 | +@info("Evaluate errors on the test set") |
| 58 | +err_100 = ACEpotentials.linear_errors(test_data, pot_100) |
| 59 | +err_300 = ACEpotentials.linear_errors(test_data, pot_300) |
| 60 | +err_1000 = ACEpotentials.linear_errors(test_data, pot_1000) |
62 | 61 |
|
63 | 62 | ## |
64 | 63 |
|
|
0 commit comments