Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ To then relax those structures with FIRE is just a few more lines.
relaxed_state = ts.optimize(
system=final_state,
model=mace_model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
autobatcher=True,
init_kwargs=dict(cell_filter=ts.CellFilter.frechet),
)
Expand Down
16 changes: 8 additions & 8 deletions examples/scripts/4_High_level_api/4.1_high_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
final_state = ts.integrate(
system=si_atoms,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -62,7 +62,7 @@
final_state = ts.integrate(
system=si_atoms,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -102,7 +102,7 @@
final_state = ts.integrate(
system=si_atoms,
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand All @@ -120,7 +120,7 @@
final_state = ts.integrate(
system=[si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell],
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand All @@ -142,7 +142,7 @@
final_state = ts.integrate(
system=systems,
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand All @@ -159,7 +159,7 @@
final_state = ts.optimize(
system=systems,
model=mace_model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
max_steps=10 if SMOKE_TEST else 1000,
init_kwargs=dict(cell_filter=ts.CellFilter.unit),
)
Expand All @@ -171,7 +171,7 @@
final_state = ts.optimize(
system=systems,
model=mace_model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
convergence_fn=lambda state, last_energy: last_energy - state.energy
< 1e-6 * MetalUnits.energy,
max_steps=10 if SMOKE_TEST else 1000,
Expand All @@ -195,7 +195,7 @@
final_state = ts.integrate(
system=structure,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100 if SMOKE_TEST else 1000,
temperature=2000,
timestep=0.002,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/6_Phonons/6.1_Phonons_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b
final_state = ts.optimize(
system=struct,
model=model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
max_steps=max_steps,
init_kwargs=dict(
cell_filter=ts.CellFilter.frechet, constant_volume=True, hydrostatic_strain=True
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_relaxed_structure(
final_state = ts.optimize(
system=struct,
model=model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
max_steps=max_steps,
convergence_fn=converge_max_force,
trajectory_reporter=reporter,
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_qha_structures(
scaled_state = ts.optimize(
system=scaled_structs,
model=model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
max_steps=Nmax,
convergence_fn=ts.runners.generate_force_convergence_fn(force_tol=fmax),
autobatcher=use_autobatcher,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/6_Phonons/6.3_Conductivity_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None:
final_state = ts.optimize(
system=struct,
model=model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
max_steps=max_steps,
convergence_fn=converge_max_force,
trajectory_reporter=reporter,
Expand Down
43 changes: 26 additions & 17 deletions examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@
def run_optimization_ts( # noqa: PLR0915
*,
initial_state: SimState | OptimState,
ts_md_flavor: Literal["vv_fire", "ase_fire"],
ts_fire_flavor: Literal["vv_fire", "ase_fire"],
ts_use_frechet: bool,
force_tol: float,
max_iterations_ts: int,
) -> tuple[torch.Tensor, OptimState | None]:
"""Runs torch-sim optimization and returns convergence steps and final state."""
print(
f"\n--- Running torch-sim optimization: flavor={ts_md_flavor}, "
f"\n--- Running torch-sim optimization: flavor={ts_fire_flavor}, "
f"frechet_cell_opt={ts_use_frechet}, force_tol={force_tol} ---"
)
start_time = time.perf_counter()
Expand Down Expand Up @@ -242,7 +242,7 @@ def run_optimization_ts( # noqa: PLR0915

end_time = time.perf_counter()
print(
f"Finished torch-sim ({ts_md_flavor}, frechet={ts_use_frechet}) in "
f"Finished torch-sim ({ts_fire_flavor}, frechet={ts_use_frechet}) in "
f"{end_time - start_time:.2f} seconds."
)
return convergence_steps, final_state_concatenated
Expand Down Expand Up @@ -377,7 +377,10 @@ def run_optimization_ase( # noqa: C901, PLR0915

if not all_positions: # If all optimizations failed early
print("Warning: No successful ASE structures to form OptimState.")
return torch.tensor(convergence_steps_list, dtype=torch.long, device=DEVICE), None
return (
torch.tensor(convergence_steps_list, dtype=torch.long, device=DEVICE),
None,
)

# Concatenate all parts
concatenated_positions = torch.cat(all_positions, dim=0)
Expand Down Expand Up @@ -429,25 +432,25 @@ def run_optimization_ase( # noqa: C901, PLR0915
{
"name": "torch-sim VV-FIRE (PosOnly)",
"type": "torch-sim",
"ts_md_flavor": "vv_fire",
"ts_fire_flavor": "vv_fire",
"ts_use_frechet": False,
},
{
"name": "torch-sim ASE-FIRE (PosOnly)",
"type": "torch-sim",
"ts_md_flavor": "ase_fire",
"ts_fire_flavor": "ase_fire",
"ts_use_frechet": False,
},
{
"name": "torch-sim VV-FIRE (Frechet Cell)",
"type": "torch-sim",
"ts_md_flavor": "vv_fire",
"ts_fire_flavor": "vv_fire",
"ts_use_frechet": True,
},
{
"name": "torch-sim ASE-FIRE (Frechet Cell)",
"type": "torch-sim",
"ts_md_flavor": "ase_fire",
"ts_fire_flavor": "ase_fire",
"ts_use_frechet": True,
},
{
Expand Down Expand Up @@ -475,19 +478,19 @@ class ResultData(TypedDict):
for config_run in configs_to_run:
print(f"\n\nStarting configuration: {config_run['name']}")
optimizer_type_val = config_run["type"]
ts_md_flavor_val = config_run.get("ts_md_flavor")
ts_fire_flavor_val = config_run.get("ts_fire_flavor")
ts_use_frechet_val = config_run.get("ts_use_frechet", False)
ase_use_frechet_filter_val = config_run.get("ase_use_frechet_filter", False)

steps: torch.Tensor | None = None
final_state_opt: OptimState | None = None

if optimizer_type_val == "torch-sim":
if ts_md_flavor_val is None:
raise ValueError(f"{ts_md_flavor_val=} must be provided for torch-sim")
if ts_fire_flavor_val is None:
raise ValueError(f"{ts_fire_flavor_val=} must be provided for torch-sim")
steps, final_state_opt = run_optimization_ts(
initial_state=state.clone(),
ts_md_flavor=ts_md_flavor_val,
ts_fire_flavor=ts_fire_flavor_val,
ts_use_frechet=ts_use_frechet_val,
force_tol=force_tol,
max_iterations_ts=max_iterations,
Expand Down Expand Up @@ -641,10 +644,12 @@ class ResultData(TypedDict):
x=structure_names,
y=steps_data_fig1[:, i],
text=[
"NC"
if all_results[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] == -1
and not np.isnan(steps_data_fig1[bar_idx, i])
else ""
(
"NC"
if all_results[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] == -1
and not np.isnan(steps_data_fig1[bar_idx, i])
else ""
)
for bar_idx in range(num_structures_plot)
],
textposition="inside",
Expand Down Expand Up @@ -827,7 +832,11 @@ class ResultData(TypedDict):
baseline_ase_pos_only,
"TS ASE PosOnly vs ASE Native",
),
("torch-sim VV-FIRE (PosOnly)", baseline_ase_pos_only, "TS VV PosOnly vs ASE Native"),
(
"torch-sim VV-FIRE (PosOnly)",
baseline_ase_pos_only,
"TS VV PosOnly vs ASE Native",
),
(
"torch-sim ASE-FIRE (Frechet Cell)",
baseline_ase_frechet,
Expand Down
20 changes: 10 additions & 10 deletions examples/tutorials/high_level_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
final_state = ts.integrate(
system=cu_atoms, # Input atomic system
model=lj_model, # Energy/force model
integrator=ts.MdFlavor.nvt_langevin, # Integrator to use
integrator=ts.Integrator.nvt_langevin, # Integrator to use
n_steps=n_steps, # Number of MD steps
temperature=2000, # Target temperature (K)
timestep=0.002, # Integration timestep (ps)
Expand All @@ -98,7 +98,7 @@
final_state = ts.integrate(
system=cu_atoms,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -151,7 +151,7 @@
final_state = ts.integrate(
system=cu_atoms,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -209,7 +209,7 @@
final_state = ts.integrate(
system=cu_atoms,
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -245,7 +245,7 @@
final_state = ts.integrate(
system=systems, # List of systems to simulate
model=mace_model, # Single model for all systems
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -277,7 +277,7 @@
final_state = ts.integrate(
system=systems,
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -327,7 +327,7 @@
final_state = ts.integrate(
system=systems,
model=mace_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down Expand Up @@ -356,7 +356,7 @@
final_state = ts.optimize(
system=systems,
model=mace_model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
init_kwargs=dict(cell_filter=ts.CellFilter.unit),
)

Expand Down Expand Up @@ -404,7 +404,7 @@ def default_energy_convergence(state, last_energy):
final_state = ts.optimize(
system=systems,
model=mace_model,
optimizer=ts.OptimFlavor.fire,
optimizer=ts.Optimizer.fire,
convergence_fn=force_convergence_fn, # Custom convergence function
init_kwargs=dict(cell_filter=ts.CellFilter.unit),
)
Expand Down Expand Up @@ -481,7 +481,7 @@ def default_energy_convergence(state, last_energy):
final_state = ts.integrate(
system=structure,
model=lj_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=2000,
timestep=0.002,
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/low_level_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
# %% [markdown]
"""
You can set the optimizer-specific arguments in the `optimize` function
optimizer=ts.OptimFlavor.fire, cell_filter=ts.CellFilter.unit. Fixed
optimizer=ts.Optimizer.fire, cell_filter=ts.CellFilter.unit. Fixed
parameters can usually be passed to the `init_fn` and parameters that vary over
the course of the simulation can be passed to the step_fn`.
"""
Expand Down
4 changes: 2 additions & 2 deletions examples/tutorials/metatomic_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
equilibrated_state = ts.integrate(
system=atoms,
model=model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=100,
temperature=300, # K
timestep=0.001, # ps
Expand All @@ -54,7 +54,7 @@
final_state = ts.integrate(
system=equilibrated_state,
model=model,
integrator=ts.MdFlavor.nve,
integrator=ts.Integrator.nve,
n_steps=100,
temperature=300, # K
timestep=0.001, # ps
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/using_graphpes_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
final_state = ts.integrate(
system=atoms,
model=ts_model,
integrator=ts.MdFlavor.nvt_langevin,
integrator=ts.Integrator.nvt_langevin,
n_steps=50,
temperature=300,
timestep=0.001,
Expand Down
Loading