Skip to content

Commit 68fc87e

Browse files
authored
Improve MACE calculator to use mh1
self._calculator = MACECalculator( model_paths=model_paths, device=device_str, head=mace_head ) Available heads are: ['matpes_r2scan', 'mp_pbe_refit_add', 'spice_wB97M', 'oc20_usemppbe', 'omol', 'omat_pbe']
1 parent 91675a8 commit 68fc87e

File tree

1 file changed

+148
-40
lines changed

1 file changed

+148
-40
lines changed

amlpa.py

Lines changed: 148 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _parse_xyz_header(self, filename: str) -> Tuple[float, List[float]]:
327327
energy = float(energy_match.group(1))
328328
cell = [float(cell_match.group(i)) for i in range(1, 7)]
329329
return energy, cell
330-
330+
331331
def _setup_calculator(self):
332332
"""Setup MACE calculator once and store it"""
333333
if hasattr(self, '_calculator'):
@@ -345,16 +345,28 @@ def _setup_calculator(self):
345345
device_str = ','.join(gpu_devices)
346346

347347
model_paths = self.config.get('model_paths', ['/path/to/your/model'])
348-
self._calculator = MACECalculator(model_paths=model_paths, device=device_str)
348+
349+
# Get head from config, with default
350+
mace_head = self.config.get('mace_head', 'omat_pbe')
351+
352+
self.logger.log(f"Setting up MACE calculator with head: {mace_head}")
353+
354+
self._calculator = MACECalculator(
355+
model_paths=model_paths,
356+
device=device_str,
357+
head=mace_head
358+
)
349359
self.atoms.calc = self._calculator
350360

351-
self.logger.log(f"MACE calculator setup: {device_str}")
361+
self.logger.log(f"MACE calculator setup: device={device_str}, head={mace_head}")
352362

353363
except ImportError:
354364
self.logger.log("MACE not available, using EMT calculator", 'WARNING')
355365
from ase.calculators.emt import EMT
356366
self._calculator = EMT()
357367
self.atoms.calc = self._calculator
368+
369+
358370

359371
def run_single_point(self) -> bool:
360372
"""Run single point energy calculation"""
@@ -535,7 +547,9 @@ def cell_callback():
535547
import traceback
536548
self.logger.log(traceback.format_exc(), 'ERROR')
537549
return False
538-
550+
551+
552+
539553
def analyze_rmsd(self, temperature: float = None) -> Dict[str, Any]:
540554
"""Calculate RMSD between initial and current structure"""
541555
if not self.config.get('run_rmsd', False):
@@ -550,48 +564,91 @@ def analyze_rmsd(self, temperature: float = None) -> Dict[str, Any]:
550564
current_atoms = self.atoms
551565
initial_pos = self.initial_atoms.get_positions()
552566
current_pos = current_atoms.get_positions()
567+
initial_cell = self.initial_atoms.get_cell()
568+
current_cell = current_atoms.get_cell()
553569

554570
if len(initial_pos) != len(current_pos):
555571
self.logger.log(f"Structure size mismatch: {len(initial_pos)} vs {len(current_pos)}", 'ERROR')
556572
return {}
557573

558-
# Handle PBC
574+
# Handle PBC for basic RMSD calculations
559575
if self.atoms.get_pbc().any():
560576
displacement = current_pos - initial_pos
561-
displacement = find_mic(displacement, self.atoms.get_cell())[0]
562-
current_pos = initial_pos + displacement
577+
displacement = find_mic(displacement, current_cell)[0]
578+
current_pos_corrected = initial_pos + displacement
579+
else:
580+
current_pos_corrected = current_pos
581+
582+
# Calculate standard RMSD metrics (unit cell, Cartesian)
583+
rmsd_simple = self._calculate_rmsd_simple(initial_pos, current_pos_corrected)
584+
rmsd_centered = self._calculate_rmsd_centered(initial_pos, current_pos_corrected)
585+
rmsd_aligned = self._calculate_rmsd_aligned(initial_pos, current_pos_corrected)
586+
587+
# Calculate fractional coordinate RMSD (accounts for cell changes)
588+
rmsd_fractional = self._calculate_rmsd_fractional(
589+
initial_pos, current_pos, initial_cell, current_cell
590+
)
591+
592+
# Calculate supercell RMSD (like RMSD15, captures lattice parameter errors)
593+
supercell_dims = self.config.get('rmsd_supercell_dims', [3, 3, 3])
594+
rmsd_supercell = self._calculate_rmsd_supercell(
595+
self.initial_atoms, current_atoms, supercell_dims
596+
)
563597

564-
# Calculate different RMSD metrics
565-
rmsd_simple = self._calculate_rmsd_simple(initial_pos, current_pos)
566-
rmsd_centered = self._calculate_rmsd_centered(initial_pos, current_pos)
567-
rmsd_aligned = self._calculate_rmsd_aligned(initial_pos, current_pos)
598+
# Calculate cell parameter deviations
599+
cell_deviations = self._calculate_cell_deviation(initial_cell, current_cell)
568600

569-
# Per-atom analysis
601+
# Per-atom displacement analysis
570602
displacement = current_pos - initial_pos
571603
if self.atoms.get_pbc().any():
572-
displacement = find_mic(displacement, self.atoms.get_cell())[0]
604+
displacement = find_mic(displacement, current_cell)[0]
573605

574606
per_atom_displacements = np.sqrt(np.sum(displacement**2, axis=1))
575607

576608
analysis = {
609+
# Unit cell Cartesian RMSD (existing metrics)
577610
'rmsd_simple': rmsd_simple,
578611
'rmsd_centered': rmsd_centered,
579612
'rmsd_aligned': rmsd_aligned,
613+
614+
# New metrics that account for cell changes
615+
'rmsd_fractional': rmsd_fractional,
616+
'rmsd_supercell': rmsd_supercell,
617+
'supercell_dims': supercell_dims,
618+
619+
# Cell parameter deviations
620+
'cell_deviations': cell_deviations,
621+
622+
# Per-atom statistics
580623
'n_atoms': len(initial_pos),
581624
'max_displacement': np.max(per_atom_displacements),
582625
'mean_displacement': np.mean(per_atom_displacements),
583626
'std_displacement': np.std(per_atom_displacements),
584627
'per_atom_displacements': per_atom_displacements
585628
}
586629

630+
# Logging
587631
self.logger.log(f"RMSD Analysis:")
588-
self.logger.log(f" Simple RMSD: {rmsd_simple:.4f} Å")
589-
self.logger.log(f" Centered RMSD: {rmsd_centered:.4f} Å")
590-
self.logger.log(f" Aligned RMSD: {rmsd_aligned:.4f} Å")
591-
self.logger.log(f" Max displacement: {analysis['max_displacement']:.4f} Å")
632+
self.logger.log(f" Unit Cell (Cartesian):")
633+
self.logger.log(f" Simple RMSD: {rmsd_simple:.4f} Å")
634+
self.logger.log(f" Centered RMSD: {rmsd_centered:.4f} Å")
635+
self.logger.log(f" Aligned RMSD (Kabsch): {rmsd_aligned:.4f} Å")
636+
self.logger.log(f" Fractional Coordinates:")
637+
self.logger.log(f" RMSD (fractional): {rmsd_fractional:.6f}")
638+
self.logger.log(f" Supercell {supercell_dims}:")
639+
self.logger.log(f" RMSD (supercell): {rmsd_supercell:.4f} Å")
640+
self.logger.log(f" Cell Parameters:")
641+
self.logger.log(f" Δa: {cell_deviations['delta_a']:.4f} Å")
642+
self.logger.log(f" Δb: {cell_deviations['delta_b']:.4f} Å")
643+
self.logger.log(f" Δc: {cell_deviations['delta_c']:.4f} Å")
644+
self.logger.log(f" ΔV: {cell_deviations['volume_change_percent']:.2f}%")
645+
self.logger.log(f" Per-atom Displacement:")
646+
self.logger.log(f" Max: {analysis['max_displacement']:.4f} Å")
647+
self.logger.log(f" Mean: {analysis['mean_displacement']:.4f} Å")
592648

593649
return analysis
594-
650+
651+
595652
def _calculate_rmsd_simple(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
596653
"""Simple RMSD calculation"""
597654
return np.sqrt(np.mean(np.sum((pos1 - pos2)**2, axis=1)))
@@ -601,7 +658,7 @@ def _calculate_rmsd_centered(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
601658
pos1_centered = pos1 - np.mean(pos1, axis=0)
602659
pos2_centered = pos2 - np.mean(pos2, axis=0)
603660
return self._calculate_rmsd_simple(pos1_centered, pos2_centered)
604-
661+
605662
def _calculate_rmsd_aligned(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
606663
"""RMSD after optimal alignment (Kabsch algorithm)"""
607664
try:
@@ -623,7 +680,54 @@ def _calculate_rmsd_aligned(self, pos1: np.ndarray, pos2: np.ndarray) -> float:
623680
except Exception as e:
624681
self.logger.log(f"Alignment failed, using centered RMSD: {e}", 'WARNING')
625682
return self._calculate_rmsd_centered(pos1, pos2)
626-
683+
684+
def _calculate_rmsd_fractional(self, pos1: np.ndarray, pos2: np.ndarray,
685+
cell1: Cell, cell2: Cell) -> float:
686+
"""RMSD in fractional coordinates (accounts for cell changes)"""
687+
# Convert Cartesian positions to fractional coordinates
688+
frac1 = cell1.scaled_positions(pos1)
689+
frac2 = cell2.scaled_positions(pos2)
690+
691+
# Handle PBC wrapping in fractional space
692+
frac_diff = frac2 - frac1
693+
frac_diff = frac_diff - np.round(frac_diff)
694+
695+
return np.sqrt(np.mean(np.sum(frac_diff**2, axis=1)))
696+
697+
def _calculate_rmsd_supercell(self, atoms1: Atoms, atoms2: Atoms,
698+
replicate: List[int] = [3, 3, 3]) -> float:
699+
"""RMSD on replicated supercell (like RMSD15, captures lattice errors)"""
700+
from ase.build import make_supercell
701+
702+
# Replicate both structures
703+
transform = np.diag(replicate)
704+
super1 = make_supercell(atoms1.copy(), transform)
705+
super2 = make_supercell(atoms2.copy(), transform)
706+
707+
pos1 = super1.get_positions()
708+
pos2 = super2.get_positions()
709+
710+
# Use Kabsch alignment on the supercell
711+
return self._calculate_rmsd_aligned(pos1, pos2)
712+
713+
def _calculate_cell_deviation(self, cell1: Cell, cell2: Cell) -> Dict[str, float]:
714+
"""Calculate cell parameter deviations"""
715+
params1 = cell1.cellpar()
716+
params2 = cell2.cellpar()
717+
718+
return {
719+
'delta_a': abs(params1[0] - params2[0]),
720+
'delta_b': abs(params1[1] - params2[1]),
721+
'delta_c': abs(params1[2] - params2[2]),
722+
'delta_alpha': abs(params1[3] - params2[3]),
723+
'delta_beta': abs(params1[4] - params2[4]),
724+
'delta_gamma': abs(params1[5] - params2[5]),
725+
'delta_volume': abs(cell1.volume - cell2.volume),
726+
'volume_change_percent': 100 * abs(cell1.volume - cell2.volume) / cell1.volume if cell1.volume > 0 else 0
727+
}
728+
729+
730+
627731
def calculate_energy_drift_statistics(self, times: np.ndarray, energies: np.ndarray,
628732
temperatures: np.ndarray, start_time_ps: float = 0.0,
629733
timestep_fs: float = 0.5) -> Dict[str, Any]:
@@ -1449,6 +1553,7 @@ def analyze_rdf_from_trajectory(self, temperature: float) -> Dict[str, Dict]:
14491553

14501554
return averaged_rdfs
14511555

1556+
14521557
def plot_rdf_comparison(self) -> None:
14531558
"""Plot RDF comparisons across temperatures"""
14541559
self.logger.log("Plotting RDF comparisons...")
@@ -1629,29 +1734,32 @@ def run_full_analysis(self, xyz_file: str) -> None:
16291734
self.plot_coordination_histogram(label="initial")
16301735

16311736
# MD simulations
1632-
temperatures = self.config.get('temperatures', [300])
1633-
md_steps = self.config.get('md_steps', 10000)
1634-
1635-
for temp in temperatures:
1636-
self.logger.log_section(f"MD ANALYSIS AT {temp}K")
1737+
if self.config.get('md_simu', True): # Default to True for backward compatibility
1738+
temperatures = self.config.get('temperatures', [300])
1739+
md_steps = self.config.get('md_steps', 10000)
16371740

1638-
traj_data = self.run_md_simulation(temp, md_steps)
1639-
1640-
if traj_data:
1641-
# Energy drift (NVE only)
1642-
drift_analysis = self.analyze_energy_drift(temp)
1643-
if drift_analysis:
1644-
self.results[f'energy_drift_{temp}K'] = drift_analysis
1645-
self.plot_energy_drift(temp)
1741+
for temp in temperatures:
1742+
self.logger.log_section(f"MD ANALYSIS AT {temp}K")
16461743

1647-
# RDF analysis
1648-
rdf_results = self.analyze_rdf_from_trajectory(temp)
1649-
if rdf_results:
1650-
self.results[f'rdf_{temp}K'] = rdf_results
1651-
1652-
# Generate comparison plots
1653-
self.plot_rdf_comparison()
1744+
traj_data = self.run_md_simulation(temp, md_steps)
1745+
1746+
if traj_data:
1747+
# Energy drift (NVE only)
1748+
drift_analysis = self.analyze_energy_drift(temp)
1749+
if drift_analysis:
1750+
self.results[f'energy_drift_{temp}K'] = drift_analysis
1751+
self.plot_energy_drift(temp)
1752+
1753+
# RDF analysis
1754+
rdf_results = self.analyze_rdf_from_trajectory(temp)
1755+
if rdf_results:
1756+
self.results[f'rdf_{temp}K'] = rdf_results
16541757

1758+
# Generate comparison plots
1759+
self.plot_rdf_comparison()
1760+
else:
1761+
self.logger.log_section("MD SIMULATIONS SKIPPED")
1762+
self.logger.log("MD simulations disabled in configuration (md_simu: false)")
16551763
# Print cache statistics
16561764
self.logger.log_section("CACHE STATISTICS")
16571765
self.cache.print_cache_summary(self.logger)

0 commit comments

Comments
 (0)