diff --git a/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py
similarity index 100%
rename from examples/scripts/2_Structural_optimization/2.9_Batched_MACE_FrechetCellFilter_FIRE.py
rename to examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py
diff --git a/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py
new file mode 100644
index 00000000..e3bf7ed5
--- /dev/null
+++ b/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py
@@ -0,0 +1,95 @@
+"""Batched MACE FIRE optimizer."""
+
+# /// script
+# dependencies = [
+#     "mace-torch>=0.3.11",
+# ]
+# ///
+
+import os
+
+import numpy as np
+import torch
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+
+import torch_sim as ts
+from torch_sim.optimizers import fire
+
+
+# Set device and data type
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float32
+
+# Option 1: Load the raw model from the downloaded model
+mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
+loaded_model = mace_mp(
+    model=mace_checkpoint_url,
+    return_raw_model=True,
+    default_dtype=dtype,
+    device=device,
+)
+
+# Option 2: Load from local file (comment out Option 1 to use this)
+# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model"
+# loaded_model = torch.load(MODEL_PATH, map_location=device)
+
+# Number of steps to run
+N_steps = 10 if os.getenv("CI") else 500
+
+# Set random seed for reproducibility
+rng = np.random.default_rng(seed=0)
+
+# Create diamond cubic Silicon
+si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
+si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)
+
+# Create FCC Copper
+cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
+cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)
+
+# Create BCC Iron
+fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
+fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)
+
+# Create a list of our atomic systems
+atoms_list = [si_dc, cu_dc, fe_dc]
+
+# Print structure information
+print(f"Silicon atoms: {len(si_dc)}")
+print(f"Copper atoms: {len(cu_dc)}")
+print(f"Iron atoms: {len(fe_dc)}")
+print(f"Total number of structures: {len(atoms_list)}")
+
+# Create batched model
+model = ts.models.MaceModel(
+    model=loaded_model,
+    device=device,
+    compute_forces=True,
+    compute_stress=True,
+    dtype=dtype,
+    enable_cueq=False,
+)
+
+# Convert atoms to state
+state = ts.state.atoms_to_state(atoms_list, device=device, dtype=dtype)
+# Run initial inference
+results = model(state)
+
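+# fire() returns a pair of functions: init_fn wraps the SimState into a FireState
+# (zero initial velocities, per-batch dt and alpha), and update_fn advances that
+# state by one FIRE step.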
+# Initialize FIRE optimizer
+init_fn, update_fn = fire(
+    model=model,
+)
+
+state = init_fn(state)
+
+# Run optimization for a few steps
+print("\nRunning FIRE:")
+for step in range(N_steps):
+    if step % 20 == 0:
+        print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
+
+    state = update_fn(state)
+
+print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
+print(f"Final energies: {[energy.item() for energy in state.energy]} eV")
diff --git a/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py
similarity index 100%
rename from examples/scripts/2_Structural_optimization/2.7_Batched_MACE_UnitCellFilter_Gradient_Descent.py
rename to examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py
diff --git a/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py
similarity index 100%
rename from examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_FIRE.py
rename to examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index ae9b9c97..5547fc78 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -3,6 +3,7 @@
 import torch
 
 from torch_sim.optimizers import (
+    fire,
     frechet_cell_fire,
     gradient_descent,
     unit_cell_fire,
@@ -101,6 +102,49 @@ def test_unit_cell_gradient_descent_optimization(
     assert not torch.allclose(state.cell, initial_state.cell)
 
 
+def test_fire_optimization(
+    ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
+) -> None:
+    """Test that the FIRE optimizer actually minimizes energy."""
+    # Add some random displacement to positions
+    perturbed_positions = (
+        ar_supercell_sim_state.positions
+        + torch.randn_like(ar_supercell_sim_state.positions) * 0.1
+    )
+
+    ar_supercell_sim_state.positions = perturbed_positions
+    initial_state = ar_supercell_sim_state
+
+    # Initialize FIRE optimizer
+    init_fn, update_fn = fire(
+        model=lj_model,
+        dt_max=0.3,
+        dt_start=0.1,
+    )
+
+    state = init_fn(ar_supercell_sim_state)
+
+    # Run optimization until the energy change drops below the tolerance
+    energies = [1000, state.energy.item()]
+    while abs(energies[-2] - energies[-1]) > 1e-6:
+        state = update_fn(state)
+        energies.append(state.energy.item())
+
+    energies = energies[1:]
+
+    # Check that energy decreased
+    assert energies[-1] < energies[0], (
+        f"FIRE optimization should reduce energy "
+        f"(initial: {energies[0]}, final: {energies[-1]})"
+    )
+
+    # Check force convergence
+    max_force = torch.max(torch.norm(state.forces, dim=1))
+    assert max_force < 0.2, f"Forces should be small after optimization (got {max_force})"
+
+    assert not torch.allclose(state.positions, initial_state.positions)
+
+
 def test_unit_cell_fire_optimization(
     ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
 ) -> None:
@@ -197,6 +241,177 @@ def test_unit_cell_frechet_fire_optimization(
     assert not torch.allclose(state.cell, initial_state.cell)
 
 
+def test_fire_multi_batch(
+    ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
+) -> None:
+    """Test FIRE optimization with multiple batches."""
+    # Create a multi-batch system by duplicating ar_supercell_sim_state
+    generator = torch.Generator(device=ar_supercell_sim_state.device)
+
+    ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state)
+    ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        generator.manual_seed(43)
+        state.positions += (
+            torch.randn(
+                state.positions.shape,
+                device=state.device,
+                generator=generator,
+            )
+            * 0.1
+        )
+
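+    # The two states are identical copies with identical perturbations, so the
+    # batched optimizer should drive both to the same minimum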
+    multi_state = concatenate_states(
+        [ar_supercell_sim_state_1, ar_supercell_sim_state_2],
+        device=ar_supercell_sim_state.device,
+    )
+
+    # Initialize FIRE optimizer
+    init_fn, update_fn = fire(
+        model=lj_model,
+        dt_max=0.3,
+        dt_start=0.1,
+    )
+
+    state = init_fn(multi_state)
+    initial_state = copy.deepcopy(state)
+
+    # Run optimization until the energy of both batches stops changing
+    prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000
+    current_energy = initial_state.energy
+    step = 0
+    while not torch.allclose(current_energy, prev_energy, atol=1e-9):
+        prev_energy = current_energy
+        state = update_fn(state)
+        current_energy = state.energy
+
+        step += 1
+        if step > 500:
+            raise ValueError("Optimization did not converge")
+
+    # Check that we actually optimized
+    assert step > 10
+
+    # Check that energy decreased for both batches
+    assert torch.all(state.energy < initial_state.energy), (
+        "FIRE optimization should reduce energy for all batches"
+    )
+
+    # Check that forces are converged for both batches
+    max_force = torch.max(torch.norm(state.forces, dim=1))
+    assert torch.all(max_force < 0.1), (
+        f"Forces should be small after optimization (got {max_force})"
+    )
+
+    n_ar_atoms = ar_supercell_sim_state.n_atoms
+    assert not torch.allclose(
+        state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms]
+    )
+    assert not torch.allclose(
+        state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:]
+    )
+
+    # We are evolving identical systems, so their final energies should match
+    assert current_energy[0] == current_energy[1]
+
+
+def test_fire_batch_consistency(
+    ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
+) -> None:
+    """Test batched FIRE optimization is consistent with individual optimizations."""
+    generator = torch.Generator(device=ar_supercell_sim_state.device)
+
+    ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state)
+    ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state)
+
+    # Add the same random perturbation to both states
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        generator.manual_seed(43)
+        state.positions += (
+            torch.randn(
+                state.positions.shape,
+                device=state.device,
+                generator=generator,
+            )
+            * 0.1
+        )
+
+    # Optimize each state individually
+    final_individual_states = []
+    total_steps = []
+
+    def energy_converged(current_energy: float, prev_energy: float) -> bool:
+        """Check if optimization should continue based on energy convergence."""
+        return not torch.allclose(current_energy, prev_energy, atol=1e-6)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        init_fn, update_fn = fire(
+            model=lj_model,
+            dt_max=0.3,
+            dt_start=0.1,
+        )
+
+        state_opt = init_fn(state)
+
+        # Run optimization until convergence
+        current_energy = state_opt.energy
+        prev_energy = current_energy + 1
+
+        step = 0
+        while energy_converged(current_energy, prev_energy):
+            prev_energy = current_energy
+            state_opt = update_fn(state_opt)
+            current_energy = state_opt.energy
+            step += 1
+            if step > 1000:
+                raise ValueError("Optimization did not converge")
+
+        final_individual_states.append(state_opt)
+        total_steps.append(step)
+
+    # Now optimize both states together in a batch
+    multi_state = concatenate_states(
+        [
+            copy.deepcopy(ar_supercell_sim_state_1),
+            copy.deepcopy(ar_supercell_sim_state_2),
+        ],
+        device=ar_supercell_sim_state.device,
+    )
+
+    init_fn, batch_update_fn = fire(
+        model=lj_model,
+        dt_max=0.3,
+        dt_start=0.1,
+    )
+
+    batch_state = init_fn(multi_state)
+
+    # Run optimization until convergence for both batches
+    current_energies = batch_state.energy.clone()
+    prev_energies = current_energies + 1
+
+    step = 0
+    while energy_converged(current_energies[0], prev_energies[0]) and energy_converged(
+        current_energies[1], prev_energies[1]
+    ):
+        prev_energies = current_energies.clone()
+        batch_state = batch_update_fn(batch_state)
+        current_energies = batch_state.energy.clone()
+        step += 1
+        if step > 1000:
+            raise ValueError("Optimization did not converge")
+
+    individual_energies = [state.energy.item() for state in final_individual_states]
+    # Check that final energies from batched optimization match individual optimizations
+    for step, individual_energy in enumerate(individual_energies):
+        assert abs(batch_state.energy[step].item() - individual_energy) < 1e-4, (
+            f"Energy for batch {step} doesn't match individual optimization: "
+            f"batch={batch_state.energy[step].item()}, individual={individual_energy}"
+        )
+
+
 def test_unit_cell_fire_multi_batch(
     ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
 ) -> None:
@@ -563,3 +778,213 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool:
         f"Energy for batch {step} doesn't match individual optimization: "
         f"batch={batch_state.energy[step].item()}, individual={individual_energy}"
     )
+
+
+def test_fire_fixed_cell_frechet_consistency(  # noqa: C901
+    ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
+) -> None:
+    """Test that fixed-cell Frechet FIRE optimization is consistent with
+    position-only FIRE optimization."""
+    generator = torch.Generator(device=ar_supercell_sim_state.device)
+
+    ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state)
+    ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state)
+
+    # Add the same random perturbation to both states
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        generator.manual_seed(43)
+        state.positions += (
+            torch.randn(
+                state.positions.shape,
+                device=state.device,
+                generator=generator,
+            )
+            * 0.1
+        )
+
+    # Optimize each state individually with the fixed-cell Frechet FIRE optimizer
+    final_individual_states_frechet = []
+    total_steps_frechet = []
+
+    def energy_converged(current_energy: float, prev_energy: float) -> bool:
+        """Check if optimization should continue based on energy convergence."""
+        return not torch.allclose(current_energy, prev_energy, atol=1e-6)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        init_fn, update_fn = frechet_cell_fire(
+            model=lj_model,
+            dt_max=0.3,
+            dt_start=0.1,
+            hydrostatic_strain=True,
+            constant_volume=True,
+        )
+
+        state_opt = init_fn(state)
+
+        # Run optimization until convergence
+        current_energy = state_opt.energy
+        prev_energy = current_energy + 1
+
+        step = 0
+        while energy_converged(current_energy, prev_energy):
+            prev_energy = current_energy
+            state_opt = update_fn(state_opt)
+            current_energy = state_opt.energy
+            step += 1
+            if step > 1000:
+                raise ValueError("Optimization did not converge")
+
+        final_individual_states_frechet.append(state_opt)
+        total_steps_frechet.append(step)
+
+    # Optimize each state individually with the position-only FIRE optimizer
+    final_individual_states_fire = []
+    total_steps_fire = []
+
+    def energy_converged(current_energy: float, prev_energy: float) -> bool:
+        """Check if optimization should continue based on energy convergence."""
+        return not torch.allclose(current_energy, prev_energy, atol=1e-6)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        init_fn, update_fn = fire(
+            model=lj_model,
+            dt_max=0.3,
+            dt_start=0.1,
+        )
+
+        state_opt = init_fn(state)
+
+        # Run optimization until convergence
+        current_energy = state_opt.energy
+        prev_energy = current_energy + 1
+
+        step = 0
+        while energy_converged(current_energy, prev_energy):
+            prev_energy = current_energy
+            state_opt = update_fn(state_opt)
+            current_energy = state_opt.energy
+            step += 1
+            if step > 1000:
+                raise ValueError("Optimization did not converge")
+
+        final_individual_states_fire.append(state_opt)
+        total_steps_fire.append(step)
+
+    individual_energies_frechet = [
+        state.energy.item() for state in final_individual_states_frechet
+    ]
+    individual_energies_fire = [
+        state.energy.item() for state in final_individual_states_fire
+    ]
+    # Check that final energies from the fixed-cell optimization match the
+    # position-only optimizations
+    for step, energy_frechet in enumerate(individual_energies_frechet):
+        assert abs(energy_frechet - individual_energies_fire[step]) < 1e-4, (
+            f"Energy for batch {step} doesn't match position only optimization: "
+            f"batch={energy_frechet}, individual={individual_energies_fire[step]}"
+        )
+
+
+def test_fire_fixed_cell_unit_cell_consistency(  # noqa: C901
+    ar_supercell_sim_state: SimState, lj_model: torch.nn.Module
+) -> None:
+    """Test that fixed-cell unit cell FIRE optimization is consistent with
+    position-only FIRE optimization."""
+    generator = torch.Generator(device=ar_supercell_sim_state.device)
+
+    ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state)
+    ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state)
+
+    # Add the same random perturbation to both states
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        generator.manual_seed(43)
+        state.positions += (
+            torch.randn(
+                state.positions.shape,
+                device=state.device,
+                generator=generator,
+            )
+            * 0.1
+        )
+
+    # Optimize each state individually with the fixed-cell unit cell FIRE optimizer
+    final_individual_states_unit_cell = []
+    total_steps_unit_cell = []
+
+    def energy_converged(current_energy: float, prev_energy: float) -> bool:
+        """Check if optimization should continue based on energy convergence."""
+        return not torch.allclose(current_energy, prev_energy, atol=1e-6)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        init_fn, update_fn = unit_cell_fire(
+            model=lj_model,
+            dt_max=0.3,
+            dt_start=0.1,
+            hydrostatic_strain=True,
+            constant_volume=True,
+        )
+
+        state_opt = init_fn(state)
+
+        # Run optimization until convergence
+        current_energy = state_opt.energy
+        prev_energy = current_energy + 1
+
+        step = 0
+        while energy_converged(current_energy, prev_energy):
+            prev_energy = current_energy
+            state_opt = update_fn(state_opt)
+            current_energy = state_opt.energy
+            step += 1
+            if step > 1000:
+                raise ValueError("Optimization did not converge")
+
+        final_individual_states_unit_cell.append(state_opt)
+        total_steps_unit_cell.append(step)
+
+    # Optimize each state individually with the position-only FIRE optimizer
+    final_individual_states_fire = []
+    total_steps_fire = []
+
+    def energy_converged(current_energy: float, prev_energy: float) -> bool:
+        """Check if optimization should continue based on energy convergence."""
+        return not torch.allclose(current_energy, prev_energy, atol=1e-6)
+
+    for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]:
+        init_fn, update_fn = fire(
+            model=lj_model,
+            dt_max=0.3,
+            dt_start=0.1,
+        )
+
+        state_opt = init_fn(state)
+
+        # Run optimization until convergence
+        current_energy = state_opt.energy
+        prev_energy = current_energy + 1
+
+        step = 0
+        while energy_converged(current_energy, prev_energy):
+            prev_energy = current_energy
+            state_opt = update_fn(state_opt)
+            current_energy = state_opt.energy
+            step += 1
+            if step > 1000:
+                raise ValueError("Optimization did not converge")
+
+        final_individual_states_fire.append(state_opt)
+        total_steps_fire.append(step)
+
+    individual_energies_unit_cell = [
+        state.energy.item() for state in final_individual_states_unit_cell
+    ]
+    individual_energies_fire = [
+        state.energy.item() for state in final_individual_states_fire
+    ]
+    # Check that final energies from fixed cell optimization match
+    # position only optimizations
+    for step, energy_unit_cell in enumerate(individual_energies_unit_cell):
+        assert abs(energy_unit_cell - individual_energies_fire[step]) < 1e-4, (
+            f"Energy for batch {step} doesn't match position only optimization: "
+            f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}"
+        )
diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index c7a44ea9..7b8e259d 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -26,7 +26,7 @@
 
 
 @dataclass
-class BatchedGDState(SimState):
+class GDState(SimState):
     """State class for batched gradient descent optimization.
 
     This class extends SimState to store and track the evolution of system state
@@ -53,8 +53,8 @@ def gradient_descent(
     *,
     lr: torch.Tensor | float = 0.01,
 ) -> tuple[
-    Callable[[StateDict | SimState], BatchedGDState],
-    Callable[[BatchedGDState], BatchedGDState],
+    Callable[[StateDict | SimState], GDState],
+    Callable[[GDState], GDState],
 ]:
     """Initialize a batched gradient descent optimization.
 
@@ -83,7 +83,7 @@ def gradient_descent(
     def gd_init(
         state: SimState | StateDict,
         **kwargs: Any,
-    ) -> BatchedGDState:
+    ) -> GDState:
         """Initialize the batched gradient descent optimization state.
 
         Args:
@@ -103,7 +103,7 @@ def gd_init(
         energy = model_output["energy"]
         forces = model_output["forces"]
 
-        return BatchedGDState(
+        return GDState(
             positions=state.positions,
             forces=forces,
             energy=energy,
@@ -114,7 +114,7 @@ def gd_init(
             batch=state.batch,
         )
 
-    def gd_step(state: BatchedGDState, lr: torch.Tensor = lr) -> BatchedGDState:
+    def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState:
         """Perform one gradient descent optimization step to update the atomic
         positions. The cell is not optimized.
 
@@ -123,7 +123,7 @@ def gd_step(state: BatchedGDState, lr: torch.Tensor = lr) -> BatchedGDState:
             lr: Learning rate(s) to use for this step, overriding the default
 
         Returns:
-            Updated BatchedGDState after one optimization step
+            Updated GDState after one optimization step
         """
         # Get per-atom learning rates by mapping batch learning rates to atoms
         if isinstance(lr, float):
@@ -147,15 +147,15 @@ def gd_step(state: BatchedGDState, lr: torch.Tensor = lr) -> BatchedGDState:
 
 
 @dataclass
-class BatchedUnitCellGDState(BatchedGDState, DeformGradMixin):
+class UnitCellGDState(GDState, DeformGradMixin):
     """State class for batched gradient descent optimization with unit cell.
 
-    Extends BatchedGDState to include unit cell optimization parameters and stress
+    Extends GDState to include unit cell optimization parameters and stress
     information. This class maintains the state variables needed for simultaneously
     optimizing atomic positions and unit cell parameters.
 
     Attributes:
-        # Inherited from BatchedGDState
+        # Inherited from GDState
         positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
         masses (torch.Tensor): Atomic masses with shape [n_atoms]
         cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
@@ -203,8 +203,8 @@ def unit_cell_gradient_descent(  # noqa: PLR0915, C901
     constant_volume: bool = False,
     scalar_pressure: float = 0.0,
 ) -> tuple[
-    Callable[[SimState | StateDict], BatchedUnitCellGDState],
-    Callable[[BatchedUnitCellGDState], BatchedUnitCellGDState],
+    Callable[[SimState | StateDict], UnitCellGDState],
+    Callable[[UnitCellGDState], UnitCellGDState],
 ]:
     """Initialize a batched gradient descent optimization with unit cell parameters.
@@ -252,7 +252,7 @@ def gd_init(
         hydrostatic_strain: bool = hydrostatic_strain,  # noqa: FBT001
         constant_volume: bool = constant_volume,  # noqa: FBT001
         scalar_pressure: float = scalar_pressure,
-    ) -> BatchedUnitCellGDState:
+    ) -> UnitCellGDState:
         """Initialize the batched gradient descent optimization state with unit cell.
 
         Args:
@@ -264,7 +264,7 @@ def gd_init(
             **kwargs: Additional keyword arguments for state initialization
 
         Returns:
-            Initial BatchedUnitCellGDState with system configuration and forces
+            Initial UnitCellGDState with system configuration and forces
         """
         if not isinstance(state, SimState):
             state = SimState(**state)
@@ -336,7 +336,7 @@ def gd_init(
         # Reshape virial for cell forces
         cell_forces = virial  # shape: (n_batches, 3, 3)
 
-        return BatchedUnitCellGDState(
+        return UnitCellGDState(
            positions=state.positions,
            forces=forces,
            energy=energy,
@@ -357,10 +357,10 @@ def gd_init(
         )
 
     def gd_step(
-        state: BatchedUnitCellGDState,
+        state: UnitCellGDState,
         positions_lr: torch.Tensor = positions_lr,
         cell_lr: torch.Tensor = cell_lr,
-    ) -> BatchedUnitCellGDState:
+    ) -> UnitCellGDState:
         """Perform one gradient descent optimization step with unit cell.
 
         Updates both atomic positions and cell parameters based on forces and stress.
@@ -371,7 +371,7 @@ def gd_step(
             cell_lr: Learning rate for unit cell optimization
 
         Returns:
-            Updated BatchedUnitCellGDState after one optimization step
+            Updated UnitCellGDState after one optimization step
         """
         # Get dimensions
         n_batches = state.n_batches
@@ -446,7 +446,253 @@ def gd_step(
 
 
 @dataclass
-class BatchedUnitCellFireState(SimState, DeformGradMixin):
+class FireState(SimState):
+    """State information for batched FIRE optimization.
+
+    This class extends SimState to store and track the system state during FIRE
+    (Fast Inertial Relaxation Engine) optimization. It maintains the atomic
+    parameters along with their velocities and forces for structure relaxation using
+    the FIRE algorithm.
+
+    Attributes:
+        # Inherited from SimState
+        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
+        masses (torch.Tensor): Atomic masses with shape [n_atoms]
+        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
+        pbc (bool): Whether to use periodic boundary conditions
+        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
+        batch (torch.Tensor): Batch indices with shape [n_atoms]
+
+        # Atomic quantities
+        forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3]
+        velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3]
+        energy (torch.Tensor): Energy per batch with shape [n_batches]
+
+        # FIRE optimization parameters
+        dt (torch.Tensor): Current timestep per batch with shape [n_batches]
+        alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches]
+        n_pos (torch.Tensor): Number of positive power steps per batch with shape
+            [n_batches]
+
+    Properties:
+        momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3],
+            calculated as velocities * masses
+    """
+
+    # Required attributes not in SimState
+    forces: torch.Tensor
+    energy: torch.Tensor
+    velocities: torch.Tensor
+
+    # FIRE algorithm parameters
+    dt: torch.Tensor
+    alpha: torch.Tensor
+    n_pos: torch.Tensor
+
+
+def fire(
+    model: torch.nn.Module,
+    *,
+    dt_max: float = 1.0,
+    dt_start: float = 0.1,
+    n_min: int = 5,
+    f_inc: float = 1.1,
+    f_dec: float = 0.5,
+    alpha_start: float = 0.1,
+    f_alpha: float = 0.99,
+) -> tuple[
+    Callable[[SimState | StateDict], FireState],
+    Callable[[FireState], FireState],
+]:
+    """Initialize a batched FIRE optimization.
+
+    Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine)
+    optimization on atomic positions.
+
+    Args:
+        model (torch.nn.Module): Model that computes energies and forces
+        dt_max (float): Maximum allowed timestep
+        dt_start (float): Initial timestep
+        n_min (int): Minimum steps before timestep increase
+        f_inc (float): Factor for timestep increase when power is positive
+        f_dec (float): Factor for timestep decrease when power is negative
+        alpha_start (float): Initial velocity mixing parameter
+        f_alpha (float): Factor for mixing parameter decrease
+
+    Returns:
+        tuple: A pair of functions:
+            - Initialization function that creates a FireState
+            - Update function that performs one FIRE optimization step
+
+    Notes:
+        - FIRE is generally more efficient than standard gradient descent for atomic
+          structure optimization
+        - The algorithm adaptively adjusts step sizes and mixing parameters based
+          on the dot product of forces and velocities
+    """
+    device = model.device
+    dtype = model.dtype
+
+    eps = 1e-8 if dtype == torch.float32 else 1e-16
+
+    # Setup parameters
+    params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min]
+    dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [
+        (
+            p
+            if isinstance(p, torch.Tensor)
+            else torch.tensor(p, device=device, dtype=dtype)
+        )
+        for p in params
+    ]
+
+    def fire_init(
+        state: SimState | StateDict,
+        dt_start: float = dt_start,
+        alpha_start: float = alpha_start,
+    ) -> FireState:
+        """Initialize a batched FIRE optimization state.
+
+        Args:
+            state: Input state as SimState object or state parameter dict
+            dt_start: Initial timestep per batch
+            alpha_start: Initial mixing parameter per batch
+
+        Returns:
+            FireState with initialized optimization tensors
+        """
+        if not isinstance(state, SimState):
+            state = SimState(**state)
+
+        # Get dimensions
+        n_batches = state.n_batches
+
+        # Get initial forces and energy from model
+        model_output = model(state)
+
+        energy = model_output["energy"]  # [n_batches]
+        forces = model_output["forces"]  # [n_total_atoms, 3]
+
+        # Setup parameters
+        dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype)
+        alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype)
+
+        n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)
+
+        # Create initial state
+        return FireState(
+            # Copy SimState attributes
+            positions=state.positions.clone(),
+            masses=state.masses.clone(),
+            cell=state.cell.clone(),
+            atomic_numbers=state.atomic_numbers.clone(),
+            batch=state.batch.clone(),
+            pbc=state.pbc,
+            # New attributes
+            velocities=torch.zeros_like(state.positions),
+            forces=forces,
+            energy=energy,
+            # Optimization attributes
+            dt=dt_start,
+            alpha=alpha_start,
+            n_pos=n_pos,
+        )
+
+    def fire_step(
+        state: FireState,
+        alpha_start: float = alpha_start,
+        dt_start: float = dt_start,
+    ) -> FireState:
+        """Perform one FIRE optimization step for batched atomic systems.
+
+        Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for
+        optimizing atomic positions in a batched setting. Uses velocity Verlet
+        integration with adaptive velocity mixing.
+
+        Args:
+            state: Current optimization state containing atomic parameters
+            alpha_start: Initial mixing parameter for velocity update
+            dt_start: Initial timestep for velocity Verlet integration
+
+        Returns:
+            Updated state after performing one FIRE step
+        """
+        n_batches = state.n_batches
+
+        # Setup parameters
+        dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype)
+        alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype)
+
+        # Velocity Verlet first half step (v += 0.5*a*dt)
+        atom_wise_dt = state.dt[state.batch].unsqueeze(-1)
+        state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
+
+        # Atomic positions are the only degrees of freedom (no cell optimization here)
+        atomic_positions = state.positions  # shape: (n_atoms, 3)
+
+        # Update atomic positions
+        atomic_positions_new = atomic_positions + atom_wise_dt * state.velocities
+
+        # Update state with new positions
+        state.positions = atomic_positions_new
+
+        # Get new forces and energy
+        results = model(state)
+        state.energy = results["energy"]
+        state.forces = results["forces"]
+
+        # Velocity Verlet second half step (v += 0.5*a*dt)
+        state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
+
+        # Calculate power (F·V) for atoms
+        atomic_power = (state.forces * state.velocities).sum(dim=1)  # [n_atoms]
+        atomic_power_per_batch = torch.zeros(
+            n_batches, device=device, dtype=atomic_power.dtype
+        )
+        atomic_power_per_batch.scatter_add_(
+            dim=0, index=state.batch, src=atomic_power
+        )  # [n_batches]
+
+        # Total power per batch (positions only; there are no cell DOFs here)
+        batch_power = atomic_power_per_batch
+
+        for batch_idx in range(n_batches):
+            # FIRE specific updates
+            if batch_power[batch_idx] > 0:  # Power is positive
+                state.n_pos[batch_idx] += 1
+                if state.n_pos[batch_idx] > n_min:
+                    state.dt[batch_idx] = min(state.dt[batch_idx] * f_inc, dt_max)
+                    state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha
+            else:  # Power is negative
+                state.n_pos[batch_idx] = 0
+                state.dt[batch_idx] = state.dt[batch_idx] * f_dec
+                state.alpha[batch_idx] = alpha_start[batch_idx]
+                # Reset velocities for the atoms in this batch
+                state.velocities[state.batch == batch_idx] = 0
+
+        # Mix velocity and force direction using FIRE for atoms
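+        # v <- (1 - alpha) * v + alpha * |v| * F / |F|: steer the velocity towards
+        # the force direction while preserving its magnitude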
+        v_norm = torch.norm(state.velocities, dim=1, keepdim=True)
+        f_norm = torch.norm(state.forces, dim=1, keepdim=True)
+        # Avoid division by zero by adding eps to the force norm
+        batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1)
+        state.velocities = (
+            1.0 - batch_wise_alpha
+        ) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + eps)
+
+        return state
+
+    return fire_init, fire_step
+
+
+@dataclass
+class UnitCellFireState(SimState, DeformGradMixin):
     """State information for batched FIRE optimization with unit cell
     degrees of freedom.
@@ -535,8 +781,8 @@ def unit_cell_fire(  # noqa: C901, PLR0915
     constant_volume: bool = False,
     scalar_pressure: float = 0.0,
 ) -> tuple[
-    BatchedUnitCellFireState,
-    Callable[[BatchedUnitCellFireState], BatchedUnitCellFireState],
+    UnitCellFireState,
+    Callable[[UnitCellFireState], UnitCellFireState],
 ]:
     """Initialize a batched FIRE optimization with unit cell degrees of freedom.
 
@@ -598,7 +844,7 @@ def fire_init(
         scalar_pressure: float = scalar_pressure,
         dt_start: float = dt_start,
         alpha_start: float = alpha_start,
-    ) -> BatchedUnitCellFireState:
+    ) -> UnitCellFireState:
         """Initialize a batched FIRE optimization state with unit cell.
 
         Args:
@@ -610,7 +856,7 @@ def fire_init(
             alpha_start: Initial mixing parameter per batch
 
         Returns:
-            BatchedUnitCellFireState with initialized optimization tensors
+            UnitCellFireState with initialized optimization tensors
         """
         if not isinstance(state, SimState):
             state = SimState(**state)
@@ -678,7 +924,7 @@ def fire_init(
         n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)
 
         # Create initial state
-        return BatchedUnitCellFireState(
+        return UnitCellFireState(
             # Copy SimState attributes
             positions=state.positions.clone(),
             masses=state.masses.clone(),
@@ -708,10 +954,10 @@ def fire_init(
         )
 
     def fire_step(  # noqa: PLR0915
-        state: BatchedUnitCellFireState,
+        state: UnitCellFireState,
         alpha_start: float = alpha_start,
         dt_start: float = dt_start,
-    ) -> BatchedUnitCellFireState:
+    ) -> UnitCellFireState:
         """Perform one FIRE optimization step for batched atomic systems with unit
         cell optimization.
 
@@ -862,7 +1108,7 @@ def fire_step(  # noqa: PLR0915
 
 
 @dataclass
-class BatchedFrechetCellFIREState(SimState, DeformGradMixin):
+class FrechetCellFIREState(SimState, DeformGradMixin):
     """State class for batched FIRE optimization with Frechet cell derivatives.
 
     This class extends SimState to store and track the system state during FIRE
@@ -951,8 +1197,8 @@ def frechet_cell_fire(  # noqa: C901, PLR0915
     constant_volume: bool = False,
     scalar_pressure: float = 0.0,
 ) -> tuple[
-    BatchedFrechetCellFIREState,
-    Callable[[BatchedFrechetCellFIREState], BatchedFrechetCellFIREState],
+    FrechetCellFIREState,
+    Callable[[FrechetCellFIREState], FrechetCellFIREState],
 ]:
     """Initialize a batched FIRE optimization with Frechet cell parameterization.
@@ -980,7 +1226,7 @@ def frechet_cell_fire(  # noqa: C901, PLR0915
 
     Returns:
         tuple: A pair of functions:
-            - Initialization function that creates a BatchedFrechetCellFIREState
+            - Initialization function that creates a FrechetCellFIREState
            - Update function that performs one FIRE step with Frechet derivatives
 
     Notes:
@@ -1014,7 +1260,7 @@ def fire_init(
         scalar_pressure: float = scalar_pressure,
         dt_start: float = dt_start,
         alpha_start: float = alpha_start,
-    ) -> BatchedFrechetCellFIREState:
+    ) -> FrechetCellFIREState:
         """Initialize a batched FIRE optimization state with Frechet cell
         parameterization.
 
@@ -1027,7 +1273,7 @@ def fire_init(
             alpha_start: Initial mixing parameter per batch
 
         Returns:
-            BatchedFrechetCellFIREState with initialized optimization tensors
+            FrechetCellFIREState with initialized optimization tensors
         """
         if not isinstance(state, SimState):
             state = SimState(**state)
@@ -1107,7 +1353,7 @@ def fire_init(
         n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)
 
         # Create initial state
-        return BatchedFrechetCellFIREState(
+        return FrechetCellFIREState(
             # Copy SimState attributes
             positions=state.positions,
             masses=state.masses,
@@ -1137,10 +1383,10 @@ def fire_init(
         )
 
     def fire_step(  # noqa: PLR0915
-        state: BatchedFrechetCellFIREState,
+        state: FrechetCellFIREState,
         alpha_start: float = alpha_start,
         dt_start: float = dt_start,
-    ) -> BatchedFrechetCellFIREState:
+    ) -> FrechetCellFIREState:
         """Perform one FIRE optimization step for batched atomic systems with Frechet
         cell parameterization.
 
diff --git a/torch_sim/unbatched/unbatched_optimizers.py b/torch_sim/unbatched/unbatched_optimizers.py
index 6631b219..d54325d3 100644
--- a/torch_sim/unbatched/unbatched_optimizers.py
+++ b/torch_sim/unbatched/unbatched_optimizers.py
@@ -270,7 +270,7 @@ def fire_update(
         alpha_curr = state.alpha
         n_pos = state.n_pos
 
-        state = velocity_verlet(state, dt_curr, model=model)
+        state = velocity_verlet(state=state, dt=dt_curr, model=model)
 
         state.dt = dt_curr
         state.alpha = alpha_curr