Skip to content

Commit e3aa03a

Browse files
author
Christoph Ortner
committed
Merge branch 'main' into tutorial_08
2 parents 5a906b6 + a720285 commit e3aa03a

28 files changed

+456
-270
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
1313
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1414
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
1515
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
16+
EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2"
1617
EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc"
1718
ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478"
1819
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"

docs/src/tutorials/basic_julia_workflow.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,20 @@ end
8585

8686
# Finally, we delete the model to clean up.
8787
rm("TiAl_model.json")
88+
89+
# ### Fast Evaluator
90+
#
91+
# `ACEpotentials.jl` provides an experimental "fast evaluator". This tries to
92+
# merge some of the operations in the full model resulting in a "slimmer" and
93+
# usually faster evaluator. In some cases the performance gain can be multiple
94+
# factors up to an order of magnitude. This is particularly important when
95+
# using a parameter estimation solver that sparsifies. In that case, the
96+
# performance gain can be significant.
97+
#
98+
# To construct the fast evaluator, simply use
99+
# ```julia
100+
# fpot = fast_evaluator(model)
101+
# ```
102+
# An optional keyword argument `aa_static = true` can be used to optimize the
103+
# n-correlation layer for very small models (at most a few hundred parameters).
104+
# For larger models this leads to a stack overflow.

examples/zuobench/zuo_asp.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Distributed, Random, SparseArrays
33
addprocs(10, exeflags="--project=$(Base.active_project())")
44
@everywhere using ACEpotentials, PrettyTables
5+
using ACEpotentials.Models: fast_evaluator
56

67
##
78

@@ -33,32 +34,30 @@ At, yt, Wt = ACEpotentials.assemble(train_data, model)
3334
Av, yv, Wv = ACEpotentials.assemble(val_data, model)
3435

3536
@info("Compute ASP Path")
36-
solver = ACEfit.ASP(; P = P, select = :final, tsvd = true,
37-
actMax = 1000, traceFlag=true )
37+
solver = ACEfit.ASP(; P = P, select = :final, tsvd = true, actMax = 1000 )
3838
asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv)
3939

4040
##
4141

4242
@info("Pick solutions for 100, 300, 1000 parameters, compute errors")
4343

44-
@show length(asp_result["path"])
45-
path = asp_result["path"]
46-
nnzs = [ nnz(p.solution) for p in path ]
47-
I1000 = length(nnzs)
48-
I300 = findfirst(nnzs .>= 300)
49-
I100 = findfirst(nnzs .>= 100)
50-
51-
model_1000 = deepcopy(model)
52-
set_parameters!(model_1000, path[I1000].solution)
53-
model_300 = deepcopy(model)
54-
set_parameters!(model_300, path[I300].solution)
55-
model_100 = deepcopy(model)
56-
set_parameters!(model_100, path[I100].solution)
57-
58-
err_100 = ACEpotentials.linear_errors(test_data, model_100)
59-
err_300 = ACEpotentials.linear_errors(test_data, model_300)
60-
err_1000 = ACEpotentials.linear_errors(test_data, model_1000)
61-
44+
# select models from the model path
45+
model_1000 = set_parameters!( deepcopy(model),
46+
ACEfit.asp_select(asp_result, :final)[1])
47+
model_300 = set_parameters!( deepcopy(model),
48+
ACEfit.asp_select(asp_result, (:bysize, 300))[1])
49+
model_100 = set_parameters!( deepcopy(model),
50+
ACEfit.asp_select(asp_result, (:bysize, 100))[1])
51+
52+
# generate sparsified, faster evaluators
53+
pot_1000 = fast_evaluator(model_1000; aa_static = false) # static can cause stack overflow
54+
pot_300 = fast_evaluator(model_300; aa_static = true)
55+
pot_100 = fast_evaluator(model_100; aa_static = true)
56+
57+
@info("Evaluate errors on the test set")
58+
err_100 = ACEpotentials.linear_errors(test_data, pot_100)
59+
err_300 = ACEpotentials.linear_errors(test_data, pot_300)
60+
err_1000 = ACEpotentials.linear_errors(test_data, pot_1000)
6261

6362
##
6463

src/ACEpotentials.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,18 @@ include("ace1_compat.jl")
1515
# Fitting
1616
include("atoms_data.jl")
1717
include("fit_model.jl")
18+
include("repulsion_restraint.jl")
1819

1920
# Data
2021
include("example_data.jl")
2122

2223
# Misc
23-
# TODO: all of this just needs to be moved from JuLIP to AtomsBase
2424
include("analysis/dataset_analysis.jl")
2525
include("analysis/potential_analysis.jl")
2626
include("descriptor.jl")
2727

2828

2929
# TODO: to be completely rewritten
30-
# include("io.jl")
3130
# include("export.jl")
3231

3332
# Experimental
@@ -41,19 +40,23 @@ import ACEpotentials.ACE1compat: ace1_model
4140
import ACEpotentials.Models: algebraic_smoothness_prior,
4241
exp_smoothness_prior,
4342
gaussian_smoothness_prior,
44-
set_parameters!
43+
set_parameters!,
44+
fast_evaluator,
45+
@committee,
46+
set_committee!
4547
import JSON
4648

4749
export ace1_model,
4850
length_basis,
4951
algebraic_smoothness_prior,
5052
exp_smoothness_prior,
5153
gaussian_smoothness_prior,
52-
set_parameters!
53-
54+
set_parameters!,
55+
fast_evaluator,
56+
@committee,
57+
set_committee!
5458

5559
include("json_interface.jl")
5660

5761

58-
5962
end

src/ace1_compat.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const _kw_defaults = Dict(:elements => nothing,
3737
:pair_envelope => (:r, 2),
3838
#
3939
:Eref => missing,
40+
:ZBL => false,
4041
#
4142
:variable_cutoffs => false,
4243
)
@@ -300,8 +301,12 @@ function _pair_basis(kwargs)
300301
# here we use a similar convention, just need to convert to ace1-style
301302
envelope = kwargs[:pair_envelope]
302303
if envelope isa Tuple && envelope[1] == :r
304+
if kwargs[:ZBL]
305+
@warn("""It is not recommended to combine the ZBL reference potential
306+
with a repulsive pair basis. Use `pair_envelope = (:x, 0, q)` instead.""")
307+
end
303308
envelope = (:r_ace1, envelope[2])
304-
end
309+
end
305310

306311
pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec,
307312
maxq = maxq,
@@ -322,6 +327,11 @@ end
322327

323328
function ace1_model(; kwargs...)
324329

330+
# change the default for the envelope if ZBL is used
331+
if haskey(kwargs, :ZBL) && kwargs[:ZBL] && !haskey(kwargs, :envelope)
332+
kwargs = (; pair_envelope = (:x, 0, 2), kwargs...)
333+
end
334+
325335
model_spec = Dict{Symbol, Any}(:model_name => "ACE1", kwargs...)
326336

327337
kwargs = _clean_args(kwargs)
@@ -372,9 +382,11 @@ function ace1_model(; kwargs...)
372382
E0s = Dict([ key => val * u"eV" for (key, val) in Eref]...)
373383
end
374384

385+
375386
model = Models.ace_model(; elements=elements,
376387
order = cor_order,
377388
Ytype = :spherical,
389+
ZBL = kwargs[:ZBL],
378390
E0s = E0s,
379391
rbasis = rbasis,
380392
pair_basis = pairbasis,

src/atoms_data.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ function _getfuzzy(coll, key)
3333
end
3434

3535
_issimilarkey(k1, k2) = lowercase(String(k1)) == lowercase(String(k2))
36+
_issimilarkey(k1::Nothing, k2) = false
37+
_issimilarkey(k1, k2::Nothing) = false
38+
_issimilarkey(k1::Nothing, k2::Nothing) = false
3639

3740
function _find_similar_key(coll, key)
3841
for k in keys(coll)
@@ -43,38 +46,38 @@ function _find_similar_key(coll, key)
4346
return nothing
4447
end
4548

46-
function _find_similar_key(sys::ExtXYZ.Atoms, key)
47-
for k in keys(sys.system_data)
49+
function _find_similar_key(sys::AbstractSystem, key)
50+
for k in keys(sys)
4851
if _issimilarkey(k, key)
4952
return k
5053
end
5154
end
52-
for k in keys(sys.atom_data)
55+
for k in atomkeys(sys)
5356
if _issimilarkey(k, key)
5457
return k
5558
end
5659
end
5760
return nothing
5861
end
5962

60-
function _get_data_fuzzy(sys::ExtXYZ.Atoms, key)
63+
function _get_data_fuzzy(sys::AbstractSystem, key)
6164
k = _find_similar_key(sys, key)
6265
if k == nothing
6366
error("Couldn't find $key or similar in collection with keys $(keys(sys))")
6467
end
65-
if haskey(sys.system_data, k)
66-
return sys.system_data[k]
68+
if haskey(sys, k)
69+
return sys[k]
6770
end
68-
return sys.atom_data[k]
71+
return sys[:, k]
6972
end
7073

7174
_has_similar_key(coll, key) = (_find_similar_key(coll, key) != nothing)
7275

73-
function _get_data(sys::ExtXYZ.Atoms, key)
74-
if haskey(sys.system_data, key)
75-
return sys.system_data[key]
76-
elseif haskey(sys.atom_data, key)
77-
return sys.atom_data[key]
76+
function _get_data(sys::AbstractSystem, key)
77+
if haskey(sys, key)
78+
return sys[key]
79+
elseif hasatomkey(sys, key)
80+
return sys[:, key]
7881
else
7982
error("Couldn't find $key in System")
8083
end

src/fit_model.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,15 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
110110
end
111111

112112
if repulsion_restraint
113-
error("Repulsion restraint is currently not implemented")
113+
restraint_data = _rep_dimer_data_atomsbase(
114+
model;
115+
weight = restraint_weight,
116+
energy_key = Symbol(energy_key),
117+
kwargs...
118+
)
119+
append!(data, restraint_data)
120+
# return nothing
121+
# error("Repulsion restraint is currently not implemented")
114122
# if eltype(data) == AtomsData
115123
# append!(data, _rep_dimer_data(model; weight = restraint_weight))
116124
# else
@@ -190,11 +198,11 @@ function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
190198
virial_key = "virial",
191199
weights = default_weights(),
192200
verbose = true,
193-
return_efv = false
201+
return_efv = false,
194202
)
195203
data = [ AtomsData(at; energy_key = energy_key, force_key=force_key,
196204
virial_key = virial_key, weights = weights,
197-
v_ref = _get_Vref(model))
205+
v_ref = nothing)
198206
for at in raw_data ]
199207
return linear_errors(data, model; verbose=verbose, return_efv = return_efv)
200208
end

src/models/ace.jl

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -92,35 +92,16 @@ function _make_idx_AA_spec(AA_spec, A_spec)
9292
return AA_spec_idx
9393
end
9494

95-
function _make_Vref_E0s(rbasis, E0s::Nothing)
96-
NZ = _get_nz(rbasis)
97-
return _make_Vref_E0s(rbasis, [ _i2z(rbasis, i) => 0.0 for i = 1:NZ ] )
98-
end
99-
100-
# E0s can be anything with (key, value) pairs
101-
function _make_Vref_E0s(rbasis, E0s)
102-
_convert_E0s(E0s::Union{Dict, NamedTuple}) = E0s
103-
_convert_E0s(E0s::Union{AbstractVector, Tuple}) = Dict(E0s...)
104-
_convert_E0s(E0s) = error("E0s must be nothing, a NamedTuple, Dict or list of pairs")
105-
106-
NZ = _get_nz(rbasis)
107-
V0 = OneBody(_convert_E0s(E0s))
108-
if length(V0.E0) != NZ
109-
error("E0s must have the right number of elements")
110-
end
111-
112-
return V0
113-
end
114-
11595

11696
function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector,
97+
Vref,
11798
level = TotalDegree(),
11899
pair_basis = nothing,
119-
E0s = nothing,
120-
Vref = _make_Vref_E0s(rbasis, E0s), )
100+
)
121101

122-
# storing E0s with unit
123-
model_meta = Dict{String, Any}("E0s" => deepcopy(E0s))
102+
# # storing E0s with unit
103+
# model_meta = Dict{String, Any}("E0s" => deepcopy(E0s))
104+
model_meta = Dict{String, Any}()
124105

125106
# generate the coupling coefficients
126107
cgen = EquivariantModels.Rot3DCoeffs_real(0)
@@ -170,8 +151,8 @@ end
170151
# since it is implicitly already encoded in AA_spec. We need a
171152
# function `auto_level` that generates level automagically from AA_spec.
172153
function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level,
173-
pair_basis, E0s = nothing)
174-
return _generate_ace_model(rbasis, Ytype, AA_spec, level, pair_basis, E0s)
154+
pair_basis, Vref)
155+
return _generate_ace_model(rbasis, Ytype, AA_spec, Vref, level, pair_basis)
175156
end
176157

177158
# NOTE : a nicer convenience constructor is also provided in `ace_heuristics.jl`
@@ -320,9 +301,8 @@ function evaluate(model::ACEModel,
320301
val += dot(Apair, (@view ps.Wpair[:, i_z0]))
321302
end
322303
# -------------------
323-
# TODO - Vref : assume it is a OneBody
324-
@assert model.Vref isa OneBody
325-
val += model.Vref.E0[Z0]
304+
# Vref
305+
val += eval_site(model.Vref, Rs, Zs, Z0)
326306
# -------------------
327307

328308
end # @no_escape

0 commit comments

Comments
 (0)