@@ -327,7 +327,7 @@ def _parse_xyz_header(self, filename: str) -> Tuple[float, List[float]]:
327327 energy = float (energy_match .group (1 ))
328328 cell = [float (cell_match .group (i )) for i in range (1 , 7 )]
329329 return energy , cell
330-
330+
331331 def _setup_calculator (self ):
332332 """Setup MACE calculator once and store it"""
333333 if hasattr (self , '_calculator' ):
@@ -345,16 +345,28 @@ def _setup_calculator(self):
345345 device_str = ',' .join (gpu_devices )
346346
347347 model_paths = self .config .get ('model_paths' , ['/path/to/your/model' ])
348- self ._calculator = MACECalculator (model_paths = model_paths , device = device_str )
348+
349+ # Get head from config, with default
350+ mace_head = self .config .get ('mace_head' , 'omat_pbe' )
351+
352+ self .logger .log (f"Setting up MACE calculator with head: { mace_head } " )
353+
354+ self ._calculator = MACECalculator (
355+ model_paths = model_paths ,
356+ device = device_str ,
357+ head = mace_head
358+ )
349359 self .atoms .calc = self ._calculator
350360
351- self .logger .log (f"MACE calculator setup: { device_str } " )
361+ self .logger .log (f"MACE calculator setup: device= { device_str } , head= { mace_head } " )
352362
353363 except ImportError :
354364 self .logger .log ("MACE not available, using EMT calculator" , 'WARNING' )
355365 from ase .calculators .emt import EMT
356366 self ._calculator = EMT ()
357367 self .atoms .calc = self ._calculator
368+
369+
358370
359371 def run_single_point (self ) -> bool :
360372 """Run single point energy calculation"""
@@ -535,7 +547,9 @@ def cell_callback():
535547 import traceback
536548 self .logger .log (traceback .format_exc (), 'ERROR' )
537549 return False
538-
550+
551+
552+
539553 def analyze_rmsd (self , temperature : float = None ) -> Dict [str , Any ]:
540554 """Calculate RMSD between initial and current structure"""
541555 if not self .config .get ('run_rmsd' , False ):
@@ -550,48 +564,91 @@ def analyze_rmsd(self, temperature: float = None) -> Dict[str, Any]:
550564 current_atoms = self .atoms
551565 initial_pos = self .initial_atoms .get_positions ()
552566 current_pos = current_atoms .get_positions ()
567+ initial_cell = self .initial_atoms .get_cell ()
568+ current_cell = current_atoms .get_cell ()
553569
554570 if len (initial_pos ) != len (current_pos ):
555571 self .logger .log (f"Structure size mismatch: { len (initial_pos )} vs { len (current_pos )} " , 'ERROR' )
556572 return {}
557573
558- # Handle PBC
574+ # Handle PBC for basic RMSD calculations
559575 if self .atoms .get_pbc ().any ():
560576 displacement = current_pos - initial_pos
561- displacement = find_mic (displacement , self .atoms .get_cell ())[0 ]
562- current_pos = initial_pos + displacement
577+ displacement = find_mic (displacement , current_cell )[0 ]
578+ current_pos_corrected = initial_pos + displacement
579+ else :
580+ current_pos_corrected = current_pos
581+
582+ # Calculate standard RMSD metrics (unit cell, Cartesian)
583+ rmsd_simple = self ._calculate_rmsd_simple (initial_pos , current_pos_corrected )
584+ rmsd_centered = self ._calculate_rmsd_centered (initial_pos , current_pos_corrected )
585+ rmsd_aligned = self ._calculate_rmsd_aligned (initial_pos , current_pos_corrected )
586+
587+ # Calculate fractional coordinate RMSD (accounts for cell changes)
588+ rmsd_fractional = self ._calculate_rmsd_fractional (
589+ initial_pos , current_pos , initial_cell , current_cell
590+ )
591+
592+ # Calculate supercell RMSD (like RMSD15, captures lattice parameter errors)
593+ supercell_dims = self .config .get ('rmsd_supercell_dims' , [3 , 3 , 3 ])
594+ rmsd_supercell = self ._calculate_rmsd_supercell (
595+ self .initial_atoms , current_atoms , supercell_dims
596+ )
563597
564- # Calculate different RMSD metrics
565- rmsd_simple = self ._calculate_rmsd_simple (initial_pos , current_pos )
566- rmsd_centered = self ._calculate_rmsd_centered (initial_pos , current_pos )
567- rmsd_aligned = self ._calculate_rmsd_aligned (initial_pos , current_pos )
598+ # Calculate cell parameter deviations
599+ cell_deviations = self ._calculate_cell_deviation (initial_cell , current_cell )
568600
569- # Per-atom analysis
601+ # Per-atom displacement analysis
570602 displacement = current_pos - initial_pos
571603 if self .atoms .get_pbc ().any ():
572- displacement = find_mic (displacement , self . atoms . get_cell () )[0 ]
604+ displacement = find_mic (displacement , current_cell )[0 ]
573605
574606 per_atom_displacements = np .sqrt (np .sum (displacement ** 2 , axis = 1 ))
575607
576608 analysis = {
609+ # Unit cell Cartesian RMSD (existing metrics)
577610 'rmsd_simple' : rmsd_simple ,
578611 'rmsd_centered' : rmsd_centered ,
579612 'rmsd_aligned' : rmsd_aligned ,
613+
614+ # New metrics that account for cell changes
615+ 'rmsd_fractional' : rmsd_fractional ,
616+ 'rmsd_supercell' : rmsd_supercell ,
617+ 'supercell_dims' : supercell_dims ,
618+
619+ # Cell parameter deviations
620+ 'cell_deviations' : cell_deviations ,
621+
622+ # Per-atom statistics
580623 'n_atoms' : len (initial_pos ),
581624 'max_displacement' : np .max (per_atom_displacements ),
582625 'mean_displacement' : np .mean (per_atom_displacements ),
583626 'std_displacement' : np .std (per_atom_displacements ),
584627 'per_atom_displacements' : per_atom_displacements
585628 }
586629
630+ # Logging
587631 self .logger .log (f"RMSD Analysis:" )
588- self .logger .log (f" Simple RMSD: { rmsd_simple :.4f} Å" )
589- self .logger .log (f" Centered RMSD: { rmsd_centered :.4f} Å" )
590- self .logger .log (f" Aligned RMSD: { rmsd_aligned :.4f} Å" )
591- self .logger .log (f" Max displacement: { analysis ['max_displacement' ]:.4f} Å" )
632+ self .logger .log (f" Unit Cell (Cartesian):" )
633+ self .logger .log (f" Simple RMSD: { rmsd_simple :.4f} Å" )
634+ self .logger .log (f" Centered RMSD: { rmsd_centered :.4f} Å" )
635+ self .logger .log (f" Aligned RMSD (Kabsch): { rmsd_aligned :.4f} Å" )
636+ self .logger .log (f" Fractional Coordinates:" )
637+ self .logger .log (f" RMSD (fractional): { rmsd_fractional :.6f} " )
638+ self .logger .log (f" Supercell { supercell_dims } :" )
639+ self .logger .log (f" RMSD (supercell): { rmsd_supercell :.4f} Å" )
640+ self .logger .log (f" Cell Parameters:" )
641+ self .logger .log (f" Δa: { cell_deviations ['delta_a' ]:.4f} Å" )
642+ self .logger .log (f" Δb: { cell_deviations ['delta_b' ]:.4f} Å" )
643+ self .logger .log (f" Δc: { cell_deviations ['delta_c' ]:.4f} Å" )
644+ self .logger .log (f" ΔV: { cell_deviations ['volume_change_percent' ]:.2f} %" )
645+ self .logger .log (f" Per-atom Displacement:" )
646+ self .logger .log (f" Max: { analysis ['max_displacement' ]:.4f} Å" )
647+ self .logger .log (f" Mean: { analysis ['mean_displacement' ]:.4f} Å" )
592648
593649 return analysis
594-
650+
651+
595652 def _calculate_rmsd_simple (self , pos1 : np .ndarray , pos2 : np .ndarray ) -> float :
596653 """Simple RMSD calculation"""
597654 return np .sqrt (np .mean (np .sum ((pos1 - pos2 )** 2 , axis = 1 )))
@@ -601,7 +658,7 @@ def _calculate_rmsd_centered(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
601658 pos1_centered = pos1 - np .mean (pos1 , axis = 0 )
602659 pos2_centered = pos2 - np .mean (pos2 , axis = 0 )
603660 return self ._calculate_rmsd_simple (pos1_centered , pos2_centered )
604-
661+
605662 def _calculate_rmsd_aligned (self , pos1 : np .ndarray , pos2 : np .ndarray ) -> float :
606663 """RMSD after optimal alignment (Kabsch algorithm)"""
607664 try :
@@ -623,7 +680,54 @@ def _calculate_rmsd_aligned(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
623680 except Exception as e :
624681 self .logger .log (f"Alignment failed, using centered RMSD: { e } " , 'WARNING' )
625682 return self ._calculate_rmsd_centered (pos1 , pos2 )
626-
683+
684+ def _calculate_rmsd_fractional (self , pos1 : np .ndarray , pos2 : np .ndarray ,
685+ cell1 : Cell , cell2 : Cell ) -> float :
686+ """RMSD in fractional coordinates (accounts for cell changes)"""
687+ # Convert Cartesian positions to fractional coordinates
688+ frac1 = cell1 .scaled_positions (pos1 )
689+ frac2 = cell2 .scaled_positions (pos2 )
690+
691+ # Handle PBC wrapping in fractional space
692+ frac_diff = frac2 - frac1
693+ frac_diff = frac_diff - np .round (frac_diff )
694+
695+ return np .sqrt (np .mean (np .sum (frac_diff ** 2 , axis = 1 )))
696+
697+ def _calculate_rmsd_supercell (self , atoms1 : Atoms , atoms2 : Atoms ,
698+ replicate : List [int ] = [3 , 3 , 3 ]) -> float :
699+ """RMSD on replicated supercell (like RMSD15, captures lattice errors)"""
700+ from ase .build import make_supercell
701+
702+ # Replicate both structures
703+ transform = np .diag (replicate )
704+ super1 = make_supercell (atoms1 .copy (), transform )
705+ super2 = make_supercell (atoms2 .copy (), transform )
706+
707+ pos1 = super1 .get_positions ()
708+ pos2 = super2 .get_positions ()
709+
710+ # Use Kabsch alignment on the supercell
711+ return self ._calculate_rmsd_aligned (pos1 , pos2 )
712+
713+ def _calculate_cell_deviation (self , cell1 : Cell , cell2 : Cell ) -> Dict [str , float ]:
714+ """Calculate cell parameter deviations"""
715+ params1 = cell1 .cellpar ()
716+ params2 = cell2 .cellpar ()
717+
718+ return {
719+ 'delta_a' : abs (params1 [0 ] - params2 [0 ]),
720+ 'delta_b' : abs (params1 [1 ] - params2 [1 ]),
721+ 'delta_c' : abs (params1 [2 ] - params2 [2 ]),
722+ 'delta_alpha' : abs (params1 [3 ] - params2 [3 ]),
723+ 'delta_beta' : abs (params1 [4 ] - params2 [4 ]),
724+ 'delta_gamma' : abs (params1 [5 ] - params2 [5 ]),
725+ 'delta_volume' : abs (cell1 .volume - cell2 .volume ),
726+ 'volume_change_percent' : 100 * abs (cell1 .volume - cell2 .volume ) / cell1 .volume if cell1 .volume > 0 else 0
727+ }
728+
729+
730+
627731 def calculate_energy_drift_statistics (self , times : np .ndarray , energies : np .ndarray ,
628732 temperatures : np .ndarray , start_time_ps : float = 0.0 ,
629733 timestep_fs : float = 0.5 ) -> Dict [str , Any ]:
@@ -1449,6 +1553,7 @@ def analyze_rdf_from_trajectory(self, temperature: float) -> Dict[str, Dict]:
14491553
14501554 return averaged_rdfs
14511555
1556+
14521557 def plot_rdf_comparison (self ) -> None :
14531558 """Plot RDF comparisons across temperatures"""
14541559 self .logger .log ("Plotting RDF comparisons..." )
@@ -1629,29 +1734,32 @@ def run_full_analysis(self, xyz_file: str) -> None:
16291734 self .plot_coordination_histogram (label = "initial" )
16301735
16311736 # MD simulations
1632- temperatures = self .config .get ('temperatures' , [300 ])
1633- md_steps = self .config .get ('md_steps' , 10000 )
1634-
1635- for temp in temperatures :
1636- self .logger .log_section (f"MD ANALYSIS AT { temp } K" )
1737+ if self .config .get ('md_simu' , True ): # Default to True for backward compatibility
1738+ temperatures = self .config .get ('temperatures' , [300 ])
1739+ md_steps = self .config .get ('md_steps' , 10000 )
16371740
1638- traj_data = self .run_md_simulation (temp , md_steps )
1639-
1640- if traj_data :
1641- # Energy drift (NVE only)
1642- drift_analysis = self .analyze_energy_drift (temp )
1643- if drift_analysis :
1644- self .results [f'energy_drift_{ temp } K' ] = drift_analysis
1645- self .plot_energy_drift (temp )
1741+ for temp in temperatures :
1742+ self .logger .log_section (f"MD ANALYSIS AT { temp } K" )
16461743
1647- # RDF analysis
1648- rdf_results = self .analyze_rdf_from_trajectory (temp )
1649- if rdf_results :
1650- self .results [f'rdf_{ temp } K' ] = rdf_results
1651-
1652- # Generate comparison plots
1653- self .plot_rdf_comparison ()
1744+ traj_data = self .run_md_simulation (temp , md_steps )
1745+
1746+ if traj_data :
1747+ # Energy drift (NVE only)
1748+ drift_analysis = self .analyze_energy_drift (temp )
1749+ if drift_analysis :
1750+ self .results [f'energy_drift_{ temp } K' ] = drift_analysis
1751+ self .plot_energy_drift (temp )
1752+
1753+ # RDF analysis
1754+ rdf_results = self .analyze_rdf_from_trajectory (temp )
1755+ if rdf_results :
1756+ self .results [f'rdf_{ temp } K' ] = rdf_results
16541757
1758+ # Generate comparison plots
1759+ self .plot_rdf_comparison ()
1760+ else :
1761+ self .logger .log_section ("MD SIMULATIONS SKIPPED" )
1762+ self .logger .log ("MD simulations disabled in configuration (md_simu: false)" )
16551763 # Print cache statistics
16561764 self .logger .log_section ("CACHE STATISTICS" )
16571765 self .cache .print_cache_summary (self .logger )
0 commit comments