Skip to content

Commit 7773adb

Browse files
committed
referenced botnet with Ezero
1 parent c43788b commit 7773adb

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

botnet/modules/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .loss import EnergyForcesLoss, ACELoss, WeightedEnergyForcesLoss
88
from .models import (BodyOrderedModel, ScaleShiftBodyOrderedModel, SingleReadoutModel,
99
ScaleShiftNonLinearBodyOrderedModel, ScaleShiftSingleReadoutModel,
10-
ScaleShiftNonLinearSingleReadoutModel)
10+
ScaleShiftNonLinearSingleReadoutModel, NonLinearBodyOrderedModel)
1111
from .radial import BesselBasis, PolynomialCutoff
1212
from .utils import compute_mean_std_atomic_inter_energy, compute_mean_rms_energy_forces, compute_avg_num_neighbors
1313

@@ -30,6 +30,6 @@
3030
'AtomicEnergiesBlock', 'RadialEmbeddingBlock', 'LinearReadoutBlock', 'SimpleInteractionBlock', 'PolynomialCutoff',
3131
'AgnosticNoScNonlinearInteractionBlock','BesselBasis', 'EnergyForcesLoss', 'ACELoss', 'WeightedEnergyForcesLoss',
3232
'interaction_classes', 'InteractionBlock','BodyOrderedModel', 'ScaleShiftBodyOrderedModel', 'SingleReadoutModel',
33-
'ScaleShiftSingleReadoutModel','ScaleShiftNonLinearSingleReadoutModel', 'compute_mean_std_atomic_inter_energy',
34-
'compute_avg_num_neighbors',
33+
'ScaleShiftSingleReadoutModel','ScaleShiftNonLinearSingleReadoutModel', 'NonLinearBodyOrderedModel',
34+
'ScaleShiftNonLinearBodyOrderedModel','compute_mean_std_atomic_inter_energy', 'compute_avg_num_neighbors',
3535
]

botnet/tools/arg_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
3030
default='body_ordered',
3131
choices=[
3232
'body_ordered', 'scale_shift', 'single_readout', 'scale_shift_non_linear',
33-
'scale_shift_single_readout', 'scale_shift_non_linear_single_readout',
33+
'scale_shift_single_readout', 'scale_shift_non_linear_single_readout', 'body_ordered_non_linear'
3434
])
3535
parser.add_argument('--r_max', help='distance cutoff (in Ang)', type=float, default=4.0)
3636
parser.add_argument('--num_radial_basis', help='number of radial basis functions', type=int, default=8)

scripts/run_train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ def main() -> None:
181181
atomic_inter_scale=std,
182182
atomic_inter_shift=mean,
183183
)
184+
elif args.model == 'body_ordered_non_linear':
185+
model = modules.NonLinearBodyOrderedModel(
186+
**model_config,
187+
gate=gate_dict[args.gate],
188+
interaction_cls_first = modules.interaction_classes[args.interaction_first],
189+
MLP_irreps=o3.Irreps(args.MLP_irreps),
190+
)
184191
elif args.model == 'scale_shift_non_linear_single_readout':
185192
mean, std = modules.scaling_classes[args.scaling](train_loader, atomic_energies)
186193
model = modules.ScaleShiftNonLinearSingleReadoutModel(

0 commit comments

Comments
 (0)