Skip to content

Commit b577790

Browse files
committed
Update filter_nonzero_weight calls in visualise_train.py to correct API
1 parent c53584c commit b577790

File tree

2 files changed

+15
-23
lines changed

2 files changed

+15
-23
lines changed

mace/cli/visualise_train.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -569,24 +569,22 @@ def update(self, batch, output): # pylint: disable=arguments-differ
569569
self.pred_energies_per_atom.append(output["energy"] / atoms_per_config)
570570

571571
self.n_energy += filter_nonzero_weight(
572-
batch, self.ref_energies, batch.weight, batch.energy_weight, "config"
572+
batch, self.ref_energies, batch.weight, batch.energy_weight,
573573
)
574574
filter_nonzero_weight(
575-
batch, self.pred_energies, batch.weight, batch.energy_weight, "config"
575+
batch, self.pred_energies, batch.weight, batch.energy_weight,
576576
)
577577
filter_nonzero_weight(
578578
batch,
579579
self.ref_energies_per_atom,
580580
batch.weight,
581581
batch.energy_weight,
582-
"config",
583582
)
584583
filter_nonzero_weight(
585584
batch,
586585
self.pred_energies_per_atom,
587586
batch.weight,
588587
batch.energy_weight,
589-
"config",
590588
)
591589

592590
if output.get("interaction_energy") is not None and batch.energy is not None:
@@ -608,28 +606,24 @@ def update(self, batch, output): # pylint: disable=arguments-differ
608606
self.ref_interaction_energies,
609607
batch.weight,
610608
batch.energy_weight,
611-
"config",
612609
)
613610
filter_nonzero_weight(
614611
batch,
615612
self.pred_interaction_energies,
616613
batch.weight,
617614
batch.energy_weight,
618-
"config",
619615
)
620616
filter_nonzero_weight(
621617
batch,
622618
self.ref_interaction_energies_per_atom,
623619
batch.weight,
624620
batch.energy_weight,
625-
"config",
626621
)
627622
filter_nonzero_weight(
628623
batch,
629624
self.pred_interaction_energies_per_atom,
630625
batch.weight,
631626
batch.energy_weight,
632-
"config",
633627
)
634628

635629
# Forces
@@ -638,10 +632,10 @@ def update(self, batch, output): # pylint: disable=arguments-differ
638632
self.pred_forces.append(output["forces"])
639633

640634
self.n_forces += filter_nonzero_weight(
641-
batch, self.ref_forces, batch.weight, batch.forces_weight, "atom"
635+
batch, self.ref_forces, batch.weight, batch.forces_weight, spread_atoms=True,
642636
)
643637
filter_nonzero_weight(
644-
batch, self.pred_forces, batch.weight, batch.forces_weight, "atom"
638+
batch, self.pred_forces, batch.weight, batch.forces_weight, spread_atoms=True,
645639
)
646640

647641
# Stress
@@ -650,10 +644,10 @@ def update(self, batch, output): # pylint: disable=arguments-differ
650644
self.pred_stress.append(output["stress"])
651645

652646
self.n_stress += filter_nonzero_weight(
653-
batch, self.ref_stress, batch.weight, batch.stress_weight, "config"
647+
batch, self.ref_stress, batch.weight, batch.stress_weight,
654648
)
655649
filter_nonzero_weight(
656-
batch, self.pred_stress, batch.weight, batch.stress_weight, "config"
650+
batch, self.pred_stress, batch.weight, batch.stress_weight,
657651
)
658652

659653
# Virials
@@ -666,24 +660,22 @@ def update(self, batch, output): # pylint: disable=arguments-differ
666660
self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d)
667661

668662
self.n_virials += filter_nonzero_weight(
669-
batch, self.ref_virials, batch.weight, batch.virials_weight, "config"
663+
batch, self.ref_virials, batch.weight, batch.virials_weight,
670664
)
671665
filter_nonzero_weight(
672-
batch, self.pred_virials, batch.weight, batch.virials_weight, "config"
666+
batch, self.pred_virials, batch.weight, batch.virials_weight,
673667
)
674668
filter_nonzero_weight(
675669
batch,
676670
self.ref_virials_per_atom,
677671
batch.weight,
678672
batch.virials_weight,
679-
"config",
680673
)
681674
filter_nonzero_weight(
682675
batch,
683676
self.pred_virials_per_atom,
684677
batch.weight,
685678
batch.virials_weight,
686-
"config",
687679
)
688680

689681
# Dipole
@@ -705,14 +697,14 @@ def update(self, batch, output): # pylint: disable=arguments-differ
705697
self.ref_dipole_per_atom,
706698
batch.weight,
707699
batch.dipole_weight,
708-
"config",
700+
spread_quantity_vector=False,
709701
)
710702
filter_nonzero_weight(
711703
batch,
712704
self.pred_dipole_per_atom,
713705
batch.weight,
714706
batch.dipole_weight,
715-
"config",
707+
spread_quantity_vector=False,
716708
)
717709

718710
def _process_data(self, ref_list, pred_list):

mace/tools/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def update(self, batch, output): # pylint: disable=arguments-differ
611611
)
612612
self.E_computed += filter_nonzero_weight(
613613
batch, self.delta_es, batch.weight, batch.energy_weight
614-
) # DEBUG , label="delta_es")
614+
)
615615
if output.get("forces") is not None and batch.forces is not None:
616616
self.fs.append(batch.forces)
617617
self.delta_fs.append(batch.forces - output["forces"])
@@ -621,12 +621,12 @@ def update(self, batch, output): # pylint: disable=arguments-differ
621621
batch.weight,
622622
batch.forces_weight,
623623
spread_atoms=True,
624-
) # DEBUG , label="delta_fs")
624+
)
625625
if output.get("stress") is not None and batch.stress is not None:
626626
self.delta_stress.append(batch.stress - output["stress"])
627627
self.stress_computed += filter_nonzero_weight(
628628
batch, self.delta_stress, batch.weight, batch.stress_weight
629-
) # DEBUG , label="delta_stress")
629+
)
630630
if output.get("virials") is not None and batch.virials is not None:
631631
self.delta_virials.append(batch.virials - output["virials"])
632632
self.delta_virials_per_atom.append(
@@ -635,7 +635,7 @@ def update(self, batch, output): # pylint: disable=arguments-differ
635635
)
636636
self.virials_computed += filter_nonzero_weight(
637637
batch, self.delta_virials, batch.weight, batch.virials_weight
638-
) # DEBUG , label="delta_virials")
638+
)
639639
if output.get("dipole") is not None and batch.dipole is not None:
640640
self.mus.append(batch.dipole)
641641
self.delta_mus.append(batch.dipole - output["dipole"])
@@ -649,7 +649,7 @@ def update(self, batch, output): # pylint: disable=arguments-differ
649649
batch.weight,
650650
batch.dipole_weight,
651651
spread_quantity_vector=False,
652-
) # DEBUG , label="delta_mus")
652+
)
653653
if (
654654
output.get("polarizability") is not None
655655
and batch.polarizability is not None

0 commit comments

Comments
 (0)