Skip to content

Commit bddf71d

Browse files
authored
Merge pull request ACEsuit#269 from ACEsuit/co/fixes
Small fixes
2 parents ec34a73 + 8161809 commit bddf71d

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

src/ace1_compat.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,19 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut])
155155
end
156156

157157

158-
function _rin0cuts_rcut(zlist, cutoffs::Dict)
158+
function _rin0cuts_rcut(zlist, cutoffs::Dict, kwargs = nothing)
159+
function _get_r0(zi, zj)
160+
if kwargs == nothing
161+
return DefaultHypers.bond_len(zi, zj)
162+
elseif kwargs[:r0] == :bondlen
163+
return DefaultHypers.bond_len(zi, zj)
164+
elseif kwargs[:r0] isa Number
165+
return kwargs[:r0]
166+
end
167+
error("Cannot determine r0($zi, $zj) from the arguments provided.")
168+
end
159169
function rin0cut(zi, zj)
160-
r0 = DefaultHypers.bond_len(zi, zj)
170+
r0 = _get_r0(zi, zj)
161171
rin, rcut = cutoffs[zi, zj]
162172
return (rin = rin, r0 = r0, rcut = rcut)
163173
end
@@ -166,18 +176,23 @@ function _rin0cuts_rcut(zlist, cutoffs::Dict)
166176
end
167177

168178

169-
function _ace1_rin0cuts(kwargs; rcutkey = :rcut)
179+
function _ace1_rin0cuts(kwargs; rcutkey = :rcut, rinkey = :rin)
170180
elements = _get_elements(kwargs)
171181
rcut = _get_all_rcut(kwargs; _rcut = kwargs[rcutkey])
182+
if kwargs[:rin] isa Number
183+
rin = kwargs[:rin]
184+
else
185+
error("Cannot read rin; please provide a number of file an issue if a more general mechanism is needed.")
186+
end
172187
if rcut isa Number
173-
cutoffs = Dict([ (s1, s2) => (0.0, rcut) for s1 in elements, s2 in elements]...)
188+
cutoffs = Dict([ (s1, s2) => (rin, rcut) for s1 in elements, s2 in elements]...)
174189
else
175-
cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...)
190+
cutoffs = Dict([ (s1, s2) => (rin, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...)
176191
end
177192
# rcut = maximum(values(rcut)) # multitransform wants a single cutoff.
178193

179194
# construct the rin0cut structures
180-
rin0cuts = _rin0cuts_rcut(elements, cutoffs)
195+
rin0cuts = _rin0cuts_rcut(elements, cutoffs, kwargs)
181196
end
182197

183198

src/fit_model.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ export acefit!, assemble, compute_errors
1414

1515
_get_Vref(model::ACEPotential) = model.model.Vref
1616

17-
__set_params!(model::ACEPotential, coeffs) = ACEpotentials.Models.set_parameters!(model, coeffs)
18-
1917
default_weights() = Dict("default"=>Dict("E"=>30.0, "F"=>1.0, "V"=>1.0))
2018

2119
function _make_prior(model::ACEpotentials.Models.ACEPotential, smoothness, P)
@@ -163,7 +161,7 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
163161
coeffs = P \ result["C"]
164162

165163
# dispatch setting of parameters
166-
__set_params!(model, coeffs)
164+
ACEpotentials.Models.set_linear_parameters!(model, coeffs)
167165

168166
if haskey(result, "committee")
169167
co_coeffs = result["committee"]

src/models/calculators.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ function set_parameters!(V::ACEPotential, θ::AbstractVector)
5454
return set_parameters!(V, ps)
5555
end
5656

57+
function set_linear_parameters!(V::ACEPotential{<: ACEModel}, θ::AbstractVector)
58+
ps = V.ps
59+
ps1 = (WB = ps.WB, Wpair = ps.Wpair,)
60+
ps1_vec, _restruct = destructure(ps1)
61+
ps2 = _restruct(θ)
62+
ps3 = deepcopy(ps)
63+
ps3.WB[:] = ps2.WB
64+
ps3.Wpair[:] = ps2.Wpair
65+
return set_parameters!(V, ps3)
66+
end
67+
5768
# ---------------------------------------------------------------
5869
# AtomsCalculatorsUtilities / SitePotential based implementation
5970
#

0 commit comments

Comments
 (0)