Skip to content

Commit 0910d0d

Browse files
author
Christoph Ortner
committed
asp bugfixes
1 parent 929bf85 commit 0910d0d

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

src/asp.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ subject to
1313
```
1414
1515
### Constructor Keyword arguments
16+
1617
```julia
17-
ACEfit.ASP(; P = I, select = (:byerror, 1.0),
18+
ACEfit.ASP(; P = I, select = (:byerror, 1.0), tsvd = false, nstore=100,
1819
params...)
1920
```
2021
@@ -42,33 +43,30 @@ solve(solver::ASP, A, y, Aval=A, yval=y)
4243
```
4344
* `A` : `m`-by-`n` design matrix. (required)
4445
* `b` : `m`-vector. (required)
45-
* `Aval = nothing` : `p`-by-`n` validation matrix (only for `validate` mode).
46-
* `bval = nothing` : `p`- validation vector (only for `validate` mode).
46+
* `Aval = nothing` : `p`-by-`n` validation matrix
47+
* `bval = nothing` : `p`- validation vector
4748
4849
If independent `Aval` and `yval` are provided (instead of detaults `A, y`),
4950
then the solver will use this separate validation set instead of the training
5051
set to select the best solution along the model path.
51-
# """
52-
52+
"""
5353
struct ASP
5454
P
5555
select
56-
mode::Symbol
5756
tsvd::Bool
5857
nstore::Integer
5958
params
6059
end
6160

62-
function ASP(; P = I, select, mode=:train, tsvd=false, nstore=100, params...)
63-
return ASP(P, select, mode, tsvd, nstore, params)
61+
function ASP(; P = I, select, tsvd=false, nstore=100, params...)
62+
return ASP(P, select, tsvd, nstore, params)
6463
end
6564

6665
function solve(solver::ASP, A, y, Aval=A, yval=y)
6766
# Apply preconditioning
6867
AP = A / solver.P
6968
AvalP = Aval / solver.P
70-
71-
tracer = asp_homotopy(AP, y; solver.params...)
69+
tracer = asp_homotopy(AP, y; solver.params..., traceFlag = true)
7270

7371
q = length(tracer)
7472
every = max(1, q / solver.nstore)
@@ -89,7 +87,7 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)
8987

9088
return Dict( "C" => xs,
9189
"path" => new_post,
92-
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
90+
"nnzs" => length( (new_post[in][:solution]).nzind) )
9391
end
9492

9593

test/test_asp.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,21 @@ Av = A[val_indices,:]
2828
yt = y[train_indices]
2929
yv = y[val_indices]
3030

31+
3132
for (nstore, n1) in [ (20, 21), (100, 101), (200, 165)]
32-
solver = ACEfit.ASP(P=I, select = :final, nstore = nstore, loglevel=0, traceFlag=true)
33+
solver = ACEfit.ASP(; P=I, select = :final, nstore = nstore, loglevel=0)
3334
results = ACEfit.solve(solver, A, y)
3435
@test length(results["path"]) == n1
3536
end
3637

38+
##
39+
3740
for (select, tolr, tolc) in [ (:final, 10*epsn, 1),
3841
( (:byerror,1.3), 10*epsn, 1),
3942
( (:bysize,73), 1, 10) ]
4043
@show select
4144
local solver, results, C
42-
solver = ACEfit.ASP(P=I, select = select, loglevel=0, traceFlag=true)
45+
solver = ACEfit.ASP(P=I, select = select, loglevel=0)
4346
# without validation
4447
results = ACEfit.solve(solver, A, y)
4548
C = results["C"]
@@ -77,11 +80,11 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
7780
( (:bysize,73), 1, 10) ]
7881
@show select
7982
local solver, results, C
80-
solver_tsvd = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=true,
81-
nstore=100, loglevel=0, traceFlag=true)
83+
solver_tsvd = ACEfit.ASP(P=I, select=select, tsvd=true,
84+
nstore=100, loglevel=0)
8285

83-
solver = ACEfit.ASP(P=I, select=select, mode=:train, tsvd=false,
84-
nstore=100, loglevel=0, traceFlag=true)
86+
solver = ACEfit.ASP(P=I, select=select, tsvd=false,
87+
nstore=100, loglevel=0)
8588
# without validation
8689
results_tsvd = ACEfit.solve(solver_tsvd, A, y)
8790
results = ACEfit.solve(solver, A, y)

0 commit comments

Comments
 (0)