@@ -569,24 +569,22 @@ def update(self, batch, output): # pylint: disable=arguments-differ
569569 self .pred_energies_per_atom .append (output ["energy" ] / atoms_per_config )
570570
571571 self .n_energy += filter_nonzero_weight (
572- batch , self .ref_energies , batch .weight , batch .energy_weight , "config"
572+ batch , self .ref_energies , batch .weight , batch .energy_weight ,
573573 )
574574 filter_nonzero_weight (
575- batch , self .pred_energies , batch .weight , batch .energy_weight , "config"
575+ batch , self .pred_energies , batch .weight , batch .energy_weight ,
576576 )
577577 filter_nonzero_weight (
578578 batch ,
579579 self .ref_energies_per_atom ,
580580 batch .weight ,
581581 batch .energy_weight ,
582- "config" ,
583582 )
584583 filter_nonzero_weight (
585584 batch ,
586585 self .pred_energies_per_atom ,
587586 batch .weight ,
588587 batch .energy_weight ,
589- "config" ,
590588 )
591589
592590 if output .get ("interaction_energy" ) is not None and batch .energy is not None :
@@ -608,28 +606,24 @@ def update(self, batch, output): # pylint: disable=arguments-differ
608606 self .ref_interaction_energies ,
609607 batch .weight ,
610608 batch .energy_weight ,
611- "config" ,
612609 )
613610 filter_nonzero_weight (
614611 batch ,
615612 self .pred_interaction_energies ,
616613 batch .weight ,
617614 batch .energy_weight ,
618- "config" ,
619615 )
620616 filter_nonzero_weight (
621617 batch ,
622618 self .ref_interaction_energies_per_atom ,
623619 batch .weight ,
624620 batch .energy_weight ,
625- "config" ,
626621 )
627622 filter_nonzero_weight (
628623 batch ,
629624 self .pred_interaction_energies_per_atom ,
630625 batch .weight ,
631626 batch .energy_weight ,
632- "config" ,
633627 )
634628
635629 # Forces
@@ -638,10 +632,10 @@ def update(self, batch, output): # pylint: disable=arguments-differ
638632 self .pred_forces .append (output ["forces" ])
639633
640634 self .n_forces += filter_nonzero_weight (
641- batch , self .ref_forces , batch .weight , batch .forces_weight , "atom"
635+ batch , self .ref_forces , batch .weight , batch .forces_weight , spread_atoms = True ,
642636 )
643637 filter_nonzero_weight (
644- batch , self .pred_forces , batch .weight , batch .forces_weight , "atom"
638+ batch , self .pred_forces , batch .weight , batch .forces_weight , spread_atoms = True ,
645639 )
646640
647641 # Stress
@@ -650,10 +644,10 @@ def update(self, batch, output): # pylint: disable=arguments-differ
650644 self .pred_stress .append (output ["stress" ])
651645
652646 self .n_stress += filter_nonzero_weight (
653- batch , self .ref_stress , batch .weight , batch .stress_weight , "config"
647+ batch , self .ref_stress , batch .weight , batch .stress_weight ,
654648 )
655649 filter_nonzero_weight (
656- batch , self .pred_stress , batch .weight , batch .stress_weight , "config"
650+ batch , self .pred_stress , batch .weight , batch .stress_weight ,
657651 )
658652
659653 # Virials
@@ -666,24 +660,22 @@ def update(self, batch, output): # pylint: disable=arguments-differ
666660 self .pred_virials_per_atom .append (output ["virials" ] / atoms_per_config_3d )
667661
668662 self .n_virials += filter_nonzero_weight (
669- batch , self .ref_virials , batch .weight , batch .virials_weight , "config"
663+ batch , self .ref_virials , batch .weight , batch .virials_weight ,
670664 )
671665 filter_nonzero_weight (
672- batch , self .pred_virials , batch .weight , batch .virials_weight , "config"
666+ batch , self .pred_virials , batch .weight , batch .virials_weight ,
673667 )
674668 filter_nonzero_weight (
675669 batch ,
676670 self .ref_virials_per_atom ,
677671 batch .weight ,
678672 batch .virials_weight ,
679- "config" ,
680673 )
681674 filter_nonzero_weight (
682675 batch ,
683676 self .pred_virials_per_atom ,
684677 batch .weight ,
685678 batch .virials_weight ,
686- "config" ,
687679 )
688680
689681 # Dipole
@@ -705,14 +697,14 @@ def update(self, batch, output): # pylint: disable=arguments-differ
705697 self .ref_dipole_per_atom ,
706698 batch .weight ,
707699 batch .dipole_weight ,
708- "config" ,
700+ spread_quantity_vector = False ,
709701 )
710702 filter_nonzero_weight (
711703 batch ,
712704 self .pred_dipole_per_atom ,
713705 batch .weight ,
714706 batch .dipole_weight ,
715- "config" ,
707+ spread_quantity_vector = False ,
716708 )
717709
718710 def _process_data (self , ref_list , pred_list ):
0 commit comments