Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,10 @@ def ti_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
"""Create crystalline TiO2 using ASE."""
a, c = 4.60, 2.96
symbols = ["Ti", "O", "O"]
basis = [
(0.5, 0.5, 0), # Ti
(0.695679, 0.695679, 0.5), # O
]
basis = [("Ti", 0.5, 0.5, 0), ("O", 0.695679, 0.695679, 0.5)]
atoms = crystal(
symbols,
basis=basis,
symbols=[b[0] for b in basis],
basis=[b[1:] for b in basis],
spacegroup=136, # P4_2/mnm
cellpar=[a, a, c, 90, 90, 90],
)
Expand All @@ -145,13 +141,10 @@ def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
def ga_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
"""Create crystalline Ga using ASE."""
a, b, c = 4.43, 7.60, 4.56
symbols = ["Ga"]
basis = [
(0, 0.344304, 0.415401), # Ga
]
basis = [("Ga", 0, 0.344304, 0.415401)]
atoms = crystal(
symbols,
basis=basis,
symbols=[b[0] for b in basis],
basis=[b[1:] for b in basis],
spacegroup=64, # Cmce
cellpar=[a, b, c, 90, 90, 90],
)
Expand All @@ -163,14 +156,13 @@ def niti_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
"""Create crystalline NiTi using ASE."""
a, b, c = 2.89, 3.97, 4.83
alpha, beta, gamma = 90.00, 105.23, 90.00
symbols = ["Ni", "Ti"]
basis = [
(0.369548, 0.25, 0.217074), # Ni
(0.076622, 0.25, 0.671102), # Ti
("Ni", 0.369548, 0.25, 0.217074),
("Ti", 0.076622, 0.25, 0.671102),
]
atoms = crystal(
symbols,
basis=basis,
symbols=[b[0] for b in basis],
basis=[b[1:] for b in basis],
spacegroup=11,
cellpar=[a, b, c, alpha, beta, gamma],
)
Expand Down Expand Up @@ -215,6 +207,36 @@ def rattled_sio2_sim_state(
return sim_state


@pytest.fixture
def casio3_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
a, b, c = 7.9258, 7.3202, 7.0653
alpha, beta, gamma = 90.055, 95.217, 103.426
basis = [
("Ca", 0.19831, 0.42266, 0.76060),
("Ca", 0.20241, 0.92919, 0.76401),
("Ca", 0.50333, 0.75040, 0.52691),
("Si", 0.1851, 0.3875, 0.2684),
("Si", 0.1849, 0.9542, 0.2691),
("Si", 0.3973, 0.7236, 0.0561),
("O", 0.3034, 0.4616, 0.4628),
("O", 0.3014, 0.9385, 0.4641),
("O", 0.5705, 0.7688, 0.1988),
("O", 0.9832, 0.3739, 0.2655),
("O", 0.9819, 0.8677, 0.2648),
("O", 0.4018, 0.7266, 0.8296),
("O", 0.2183, 0.1785, 0.2254),
("O", 0.2713, 0.8704, 0.0938),
("O", 0.2735, 0.5126, 0.0931),
]
atoms = crystal(
symbols=[b[0] for b in basis],
basis=[b[1:] for b in basis],
spacegroup=2,
cellpar=[a, b, c, alpha, beta, gamma],
)
return ts.io.atoms_to_state(atoms, device, dtype)


@pytest.fixture
def benzene_sim_state(
benzene_atoms: Any, device: torch.device, dtype: torch.dtype
Expand Down
1 change: 1 addition & 0 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"rattled_sio2_sim_state",
"ar_supercell_sim_state",
"fe_supercell_sim_state",
"casio3_sim_state",
"benzene_sim_state",
)

Expand Down
51 changes: 34 additions & 17 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any

import pytest
import torch

from torch_sim.integrators import (
Expand Down Expand Up @@ -317,13 +318,23 @@ def test_nve(ar_double_sim_state: SimState, lj_model: LennardJonesModel):
assert torch.allclose(energies_tensor[:, 1], energies_tensor[0, 1], atol=1e-4)


@pytest.mark.parametrize(
"sim_state_fixture_name", ["casio3_sim_state", "ar_supercell_sim_state"]
)
def test_compare_single_vs_batched_integrators(
ar_supercell_sim_state: SimState, lj_model: Any
sim_state_fixture_name: str, request: pytest.FixtureRequest, lj_model: Any
) -> None:
"""Test that single and batched integrators give the same results."""
"""Test NVE single vs batched for a tilted cell to verify PBC wrapping.

NOTE: added triclinic cell after https://github.com/Radical-AI/torch-sim/issues/171.
Although the addition doesn't fail if we do not add the changes suggested in issue.
"""
sim_state = request.getfixturevalue(sim_state_fixture_name)
n_steps = 100

initial_states = {
"single": ar_supercell_sim_state,
"batched": concatenate_states([ar_supercell_sim_state, ar_supercell_sim_state]),
"single": sim_state,
"batched": concatenate_states([sim_state, sim_state]),
}

final_states = {}
Expand All @@ -333,25 +344,31 @@ def test_compare_single_vs_batched_integrators(
dt = torch.tensor(0.001) # Small timestep for stability

nve_init, nve_update = nve(model=lj_model, dt=dt, kT=kT)
state = nve_init(state=state, seed=42)
state.momenta = torch.zeros_like(state.momenta)
# Initialize momenta (even if zero) and get forces
state = nve_init(state=state, seed=42) # kT is ignored if momenta are set below
# Ensure momenta start at zero AFTER init which might randomize them based on kT
state.momenta = torch.zeros_like(state.momenta) # Start from rest

for _step in range(100):
for _step in range(n_steps):
state = nve_update(state=state, dt=dt)

final_states[state_name] = state

# Check energy conservation
ar_single_state = final_states["single"]
ar_batched_state_0 = final_states["batched"][0]
ar_batched_state_1 = final_states["batched"][1]

for final_state in [ar_batched_state_0, ar_batched_state_1]:
assert torch.allclose(ar_single_state.positions, final_state.positions)
assert torch.allclose(ar_single_state.momenta, final_state.momenta)
assert torch.allclose(ar_single_state.forces, final_state.forces)
assert torch.allclose(ar_single_state.masses, final_state.masses)
assert torch.allclose(ar_single_state.cell, final_state.cell)
single_state = final_states["single"]
batched_state_0 = final_states["batched"][0]
batched_state_1 = final_states["batched"][1]

# Compare single state results with each part of the batched state
for final_state in [batched_state_0, batched_state_1]:
# Check positions first - most likely to fail with incorrect PBC
torch.testing.assert_close(single_state.positions, final_state.positions)
# Check other state components
torch.testing.assert_close(single_state.momenta, final_state.momenta)
torch.testing.assert_close(single_state.forces, final_state.forces)
torch.testing.assert_close(single_state.masses, final_state.masses)
torch.testing.assert_close(single_state.cell, final_state.cell)
torch.testing.assert_close(single_state.energy, final_state.energy)


def test_compute_cell_force_atoms_per_batch():
Expand Down
67 changes: 36 additions & 31 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,34 @@ def test_pbc_wrap_general_orthorhombic() -> None:
assert torch.allclose(wrapped, expected)


def test_pbc_wrap_general_triclinic() -> None:
"""Test periodic boundary wrapping with triclinic cell.

Tests wrapping in a non-orthogonal cell where lattice vectors have
off-diagonal components (tilt factors). This verifies the general
matrix transformation approach works for arbitrary cell shapes.
"""
# Triclinic cell with tilt
lattice = torch.tensor(
[
[2.0, 0.5, 0.0], # a vector with b-tilt
[0.0, 2.0, 0.0], # b vector
[0.0, 0.3, 2.0], # c vector with b-tilt
]
)

# Position outside triclinic box
positions = torch.tensor([[2.5, 2.5, 2.5]])

# Correct expected wrapped position for this triclinic cell
expected = torch.tensor([[2.0, 0.5, 0.2]])

wrapped = tst.pbc_wrap_general(positions, lattice)
assert torch.allclose(wrapped, expected, atol=1e-6)
@pytest.mark.parametrize(
("cell", "shift"),
[
# Cubic cell, integer shift [1, 1, 1]
(torch.eye(3, dtype=torch.float64) * 2.0, [1, 1, 1]),
# Triclinic cell, integer shift [1, 1, 1]
(([[2.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.0, 0.3, 2.0]]), [1, 1, 1]),
# Triclinic cell, integer shift [-1, 2, 0]
(([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-1, 2, 0]),
# triclinic, all negative shift
(([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-2, -1, -3]),
# cubic, large mixed shift
(torch.eye(3, dtype=torch.float64) * 2.0, [5, 0, -10]),
# highly tilted cell
(([[1.3, 0.9, 0.8], [0.0, 1.0, 0.9], [0.0, 0.0, 1.0]]), [1, -2, 3]),
# Left-handed cell
(([[2.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 2.0]]), [1, 1, 1]),
],
)
def test_pbc_wrap_general_param(cell: torch.Tensor, shift: torch.Tensor) -> None:
"""Test periodic boundary wrapping for various cells and integer shifts."""
cell = torch.as_tensor(cell, dtype=torch.float64)
shift = torch.as_tensor(shift, dtype=torch.float64)
base_frac = torch.tensor([[0.25, 0.5, 0.75]], dtype=torch.float64)
base_cart = base_frac @ cell.T
shifted_cart = base_cart + (shift @ cell.T)
wrapped = tst.pbc_wrap_general(shifted_cart, cell)
torch.testing.assert_close(wrapped, base_cart, rtol=1e-6, atol=1e-6)


def test_pbc_wrap_general_edge_case() -> None:
Expand Down Expand Up @@ -277,35 +281,36 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: SimState) -> None:

def test_pbc_wrap_batched_triclinic(device: torch.device) -> None:
"""Test batched periodic boundary wrapping with triclinic cell."""
# Create two triclinic cells with different tilt factors
# Define cell matrices (M_row convention)
cell1 = torch.tensor(
[
[2.0, 0.5, 0.0], # a vector with b-tilt
[0.0, 2.0, 0.0], # b vector
[0.0, 0.3, 2.0], # c vector with b-tilt
],
dtype=torch.float64,
device=device,
)

cell2 = torch.tensor(
[
[2.0, 0.0, 0.5], # a vector with c-tilt
[0.3, 2.0, 0.0], # b vector with a-tilt
[0.0, 0.0, 2.0], # c vector
],
dtype=torch.float64,
device=device,
)
cell = torch.stack([cell1, cell2])

# Create positions for two atoms, one in each batch
# Define positions (r_row convention)
positions = torch.tensor(
[
[2.5, 2.5, 2.5], # First atom, outside batch 0's cell
[2.7, 2.7, 2.7], # Second atom, outside batch 1's cell
[2.5, 2.5, 2.5], # Atom 0 (batch 0)
[2.7, 2.7, 2.7], # Atom 1 (batch 1)
],
dtype=torch.float64,
device=device,
)

# Create batch indices
batch = torch.tensor([0, 1], device=device)

# Stack the cells for batched processing
Expand Down
8 changes: 2 additions & 6 deletions torch_sim/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState:

if state.pbc:
# Split positions and cells by batch
new_positions = pbc_wrap_batched(
new_positions, state.cell.swapaxes(1, 2), state.batch
)
new_positions = pbc_wrap_batched(new_positions, state.cell, state.batch)

state.positions = new_positions
return state
Expand Down Expand Up @@ -1027,9 +1025,7 @@ def langevin_position_step(

# Apply periodic boundary conditions if needed
if state.pbc:
state.positions = pbc_wrap_batched(
state.positions, state.cell.swapaxes(1, 2), state.batch
)
state.positions = pbc_wrap_batched(state.positions, state.cell, state.batch)

return state

Expand Down
37 changes: 10 additions & 27 deletions torch_sim/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def pbc_wrap_general(
This implementation follows the general matrix-based approach for
periodic boundary conditions in arbitrary triclinic cells:
1. Transform positions to fractional coordinates using B = A^(-1)
2. Wrap fractional coordinates to [0,1) using f - floor(f)
2. Wrap fractional coordinates to [0,1) using modulo
3. Transform back to real space using A

Args:
positions (torch.Tensor): Tensor of shape (..., d)
containing particle positions in real space.
lattice_vectors (torch.Tensor): Tensor of shape (d, d)
containing lattice vectors as columns (A matrix in the equations).
lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing
lattice vectors as columns (A matrix in the equations).

Returns:
torch.Tensor: Tensor of wrapped positions in real space with
Expand All @@ -124,23 +124,13 @@ def pbc_wrap_general(
if positions.shape[-1] != lattice_vectors.shape[0]:
raise ValueError("Position dimensionality must match lattice vectors.")

# Compute B = A^(-1) to transform to fractional coordinates
B = torch.linalg.inv(lattice_vectors)

# Transform to fractional coordinates: f = Br
frac_coords = positions @ B.T

# Wrap to reference cell [0,1) using f - floor(f)
wrapped_frac = frac_coords - torch.floor(frac_coords)
frac_coords = positions @ torch.linalg.inv(lattice_vectors).T

# Handle edge case of positions exactly on upper boundary
wrapped_frac = torch.where(
torch.isclose(wrapped_frac, torch.ones_like(wrapped_frac)),
torch.zeros_like(wrapped_frac),
wrapped_frac,
)
# Wrap to reference cell [0,1) using modulo
wrapped_frac = frac_coords % 1.0

# Transform back to real space: t = Ag
# Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row
return wrapped_frac @ lattice_vectors.T


Expand All @@ -157,7 +147,7 @@ def pbc_wrap_batched(
positions (torch.Tensor): Tensor of shape (n_atoms, 3) containing
particle positions in real space.
cell (torch.Tensor): Tensor of shape (n_batches, 3, 3) containing
lattice vectors for each batch.
lattice vectors as column vectors.
batch (torch.Tensor): Tensor of shape (n_atoms,) containing batch
indices for each atom.

Expand Down Expand Up @@ -191,15 +181,8 @@ def pbc_wrap_batched(
# For each atom, multiply its position by its batch's inverse cell matrix
frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2)

# Wrap to reference cell [0,1) using f - floor(f)
wrapped_frac = frac_coords - torch.floor(frac_coords)

# Handle edge case of positions exactly on upper boundary
wrapped_frac = torch.where(
torch.isclose(wrapped_frac, torch.ones_like(wrapped_frac)),
torch.zeros_like(wrapped_frac),
wrapped_frac,
)
# Wrap to reference cell [0,1) using modulo
wrapped_frac = frac_coords % 1.0

# Transform back to real space: r = A·f
# Get the cell for each atom based on its batch index
Expand Down
Loading