Skip to content

Commit ace95f2

Browse files
committed
Allow no energy data for MACEField.
1 parent e2178d3 commit ace95f2

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

mace/cli/run_train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def run(args) -> None:
309309
args.pseudolabel_replay
310310
and args.multiheads_finetuning
311311
and head_config.head_name == "pt_head"
312+
or args.model == "MACEField"
312313
),
313314
)
314315
head_config.collections = SubsetCollection(

mace/tools/scripts_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,15 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
515515
atomic_energies_dict = data.compute_average_E0s(
516516
train_collection, z_table
517517
)
518+
atomic_energies_dict = {
519+
key:value
520+
if
521+
not np.isnan(value)
522+
else
523+
0.0
524+
for
525+
key, value in atomic_energies_dict.items()
526+
}
518527
except Exception as e:
519528
raise RuntimeError(
520529
f"Could not compute average E0s if no training xyz given, error {e} occured"

0 commit comments

Comments
 (0)