Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ jobs:
- name: Install torch_sim
run: |
uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system
# always use numpy>=2 with Python 3.13
if [ "${{ matrix.version.python }}" = "3.13" ]; then
uv pip install -e ".[test]" "numpy>=2" --resolution=${{ matrix.version.resolution }} --system
else
uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system
fi
uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system

- name: Run core tests
run: |
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ requires-python = ">=3.12"
dependencies = [
"h5py>=3.12.1",
"nvalchemi-toolkit-ops>=0.2.0",
"numpy>=1.26,<3",
"numpy>=1.26,<3; python_version < '3.13'",
"numpy>=2,<3; python_version >= '3.13'",
"tables>=3.10.2,<3.11",
"torch>=2",
"tqdm>=4.67",
Expand Down
234 changes: 234 additions & 0 deletions tests/test_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import pytest
import torch

import torch_sim as ts


DEVICE = torch.device("cpu")
DTYPE = torch.float64


class TestExtras:
def test_system_extras_construction(self):
"""Extras can be passed at construction time."""
field = torch.randn(1, 3)
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"external_E_field": field},
)
assert torch.equal(state.external_E_field, field)

def test_atom_extras_construction(self):
"""Per-atom extras work at construction time."""
tags = torch.tensor([1.0, 2.0])
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_atom_extras={"tags": tags},
)
assert torch.equal(state.tags, tags)

def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState):
with pytest.raises(AttributeError, match="nonexistent_key"):
_ = cu_sim_state.nonexistent_key

def test_set_extras(self, cu_sim_state: ts.SimState):
field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device)
cu_sim_state.set_extras("E", field, scope="per-system")
assert torch.equal(cu_sim_state.E, field)

def test_set_extras_bad_shape(self, cu_sim_state: ts.SimState):
bad = torch.randn(cu_sim_state.n_systems + 5, 3)
with pytest.raises(ValueError, match="leading dim must be n_systems"):
cu_sim_state.set_extras("bad", bad, scope="per-system")

def test_clone_preserves_extras(self, cu_sim_state: ts.SimState):
field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device)
cu_sim_state.set_extras("E", field, scope="per-system")
cloned = cu_sim_state.clone()
assert torch.equal(cloned.E, field)
# verify independence
cloned.system_extras["E"].zero_()
assert not torch.equal(cu_sim_state.E, cloned.E)

def test_split_preserves_extras(self, si_double_sim_state: ts.SimState):
field = torch.randn(
si_double_sim_state.n_systems, 3, device=si_double_sim_state.device
)
si_double_sim_state.set_extras("H", field, scope="per-system")
splits = si_double_sim_state.split()
for i, s in enumerate(splits):
assert torch.equal(s.H, field[i : i + 1])

def test_getitem_preserves_extras(self, si_double_sim_state: ts.SimState):
field = torch.randn(
si_double_sim_state.n_systems, 3, device=si_double_sim_state.device
)
si_double_sim_state.set_extras("E", field, scope="per-system")
sub = si_double_sim_state[[0]]
assert torch.equal(sub.E, field[0:1])

def test_concatenate_preserves_extras(self, cu_sim_state: ts.SimState):
s1 = cu_sim_state.clone()
s2 = cu_sim_state.clone()
f1 = torch.randn(s1.n_systems, 3, device=s1.device)
f2 = torch.randn(s2.n_systems, 3, device=s2.device)
s1.set_extras("E", f1, scope="per-system")
s2.set_extras("E", f2, scope="per-system")
merged = ts.concatenate_states([s1, s2])
assert torch.equal(merged.E, torch.cat([f1, f2], dim=0))

def test_to_device_moves_extras(self, cu_sim_state: ts.SimState):
field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device)
cu_sim_state.set_extras("E", field, scope="per-system")
moved = cu_sim_state.to(device=cu_sim_state.device)
assert moved.E.device == cu_sim_state.device

def test_pop_preserves_extras(self, si_double_sim_state: ts.SimState):
field = torch.randn(
si_double_sim_state.n_systems, 3, device=si_double_sim_state.device
)
si_double_sim_state.set_extras("E", field, scope="per-system")
popped = si_double_sim_state.pop(0)
assert popped[0].E.shape[0] == 1

def test_has_extras(self, cu_sim_state: ts.SimState):
assert not cu_sim_state.has_extras("E")
cu_sim_state.set_extras(
"E",
torch.zeros(cu_sim_state.n_systems, 3, device=cu_sim_state.device),
scope="per-system",
)
assert cu_sim_state.has_extras("E")

def test_post_init_validation_rejects_bad_shape(self):
with pytest.raises(ValueError, match="leading dim must be n_systems"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"bad": torch.randn(5, 3)},
)

def test_from_state_preserves_extras(self, cu_sim_state: ts.SimState):
field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device)
cu_sim_state.set_extras("E", field, scope="per-system")
new = ts.SimState.from_state(cu_sim_state)
assert torch.equal(new.E, field)

def test_extras_cannot_shadow_declared_fields(self, cu_sim_state: ts.SimState):
# set_extras should raise if attempting to shadow
with pytest.raises(ValueError, match="shadows an existing attribute"):
cu_sim_state.set_extras(
"cell", torch.zeros(cu_sim_state.n_systems, 3), scope="per-system"
)

def test_construction_extras_cannot_shadow(self):
# Post-init validation should also catch shadowing during construction
with pytest.raises(ValueError, match="shadows an existing attribute"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"cell": torch.zeros(1, 3)},
)

# store_model_extras
def test_store_model_extras_canonical_keys_not_stored(
self, si_double_sim_state: ts.SimState
):
"""Canonical keys (energy, forces, stress) must not land in extras."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"forces": torch.randn(state.n_atoms, 3),
"stress": torch.randn(state.n_systems, 3, 3),
}
)
assert not state._system_extras # noqa: SLF001
assert not state._atom_extras # noqa: SLF001

def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_systems go into system_extras."""
state = si_double_sim_state.clone()
dipole = torch.randn(state.n_systems, 3)
state.store_model_extras(
{"energy": torch.randn(state.n_systems), "dipole": dipole}
)
assert torch.equal(state.dipole, dipole)

def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_atoms go into atom_extras."""
state = si_double_sim_state.clone()
charges = torch.randn(state.n_atoms)
density = torch.randn(state.n_atoms, 8)
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"charges": charges,
"density_coefficients": density,
}
)
assert torch.equal(state.charges, charges)
assert state.density_coefficients.shape == (state.n_atoms, 8)

def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState):
"""0-d tensors and non-Tensor values are silently ignored."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"scalar": torch.tensor(3.14),
"string": "not a tensor",
}
)
assert not state.has_extras("scalar")
assert not state.has_extras("string")


def test_system_extras_atoms_roundtrip():
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])},
)
atoms_list = state.to_atoms()
assert "external_E_field" in atoms_list[0].info
restored = ts.io.atoms_to_state(
atoms_list,
system_extras_keys=["external_E_field"],
)
assert torch.allclose(restored.external_E_field, state.external_E_field)


def test_atom_extras_atoms_roundtrip():
tags = torch.tensor([1.0, 2.0])
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_atom_extras={"tags": tags},
)
atoms_list = state.to_atoms()
assert "tags" in atoms_list[0].arrays
restored = ts.io.atoms_to_state(
atoms_list,
atom_extras_keys=["tags"],
)
assert torch.allclose(restored.tags, state.tags)
84 changes: 84 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from ase import Atoms
from ase.build import molecule
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

Expand Down Expand Up @@ -88,6 +89,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None:
)


@pytest.mark.parametrize(
("charge", "spin", "expected_charge", "expected_spin"),
[
(1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin
(0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin
(None, None, 0.0, 0.0), # No charge/spin set, defaults to zero
],
)
def test_atoms_to_state_with_charge_spin(
charge: float | None,
spin: float | None,
expected_charge: float,
expected_spin: float,
) -> None:
"""Test conversion from ASE Atoms with charge and spin to state tensors."""
mol = molecule("H2O")
if charge is not None:
mol.info["charge"] = charge
if spin is not None:
mol.info["spin"] = spin

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (1,)
assert state.spin.shape == (1,)
assert state.charge[0].item() == expected_charge
assert state.spin[0].item() == expected_spin


def test_multiple_atoms_to_state_with_charge_spin() -> None:
"""Test conversion from multiple ASE Atoms with different charge/spin values."""
mol1 = molecule("H2O")
mol1.info["charge"] = 1.0
mol1.info["spin"] = 1.0

mol2 = molecule("CH4")
mol2.info["charge"] = -1.0
mol2.info["spin"] = 0.0

mol3 = molecule("NH3")
mol3.info["charge"] = 0.0
mol3.info["spin"] = 2.0

state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (3,)
assert state.spin.shape == (3,)
assert state.charge[0].item() == 1.0
assert state.charge[1].item() == -1.0
assert state.charge[2].item() == 0.0
assert state.spin[0].item() == 1.0
assert state.spin[1].item() == 0.0
assert state.spin[2].item() == 2.0


def test_state_to_structure(ar_supercell_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of pymatgen Structure."""
structures = ts.io.state_to_structures(ar_supercell_sim_state)
Expand All @@ -114,6 +178,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None:
assert len(atoms[0]) == 32


def test_state_to_atoms_with_charge_spin() -> None:
"""Test conversion from state with charge/spin to ASE Atoms preserves charge/spin."""
mol = molecule("H2O")
mol.info["charge"] = 1.0
mol.info["spin"] = 1.0

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)
atoms = ts.io.state_to_atoms(state)

assert len(atoms) == 1
assert isinstance(atoms[0], Atoms)
assert "charge" in atoms[0].info
assert "spin" in atoms[0].info
assert atoms[0].info["charge"] == 1
assert atoms[0].info["spin"] == 1


def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of ASE Atoms."""
atoms = ts.io.state_to_atoms(ar_double_sim_state)
Expand Down Expand Up @@ -253,6 +334,9 @@ def test_state_round_trip(
# since both use their own isotope masses based on species,
# not the ones in the state
assert torch.allclose(sim_state.masses, round_trip_state.masses)
# Check charge/spin round trip
assert torch.allclose(sim_state.charge, round_trip_state.charge)
assert torch.allclose(sim_state.spin, round_trip_state.spin)


def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
1 change: 1 addition & 0 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterfac

state.energy = model_output["energy"]
state.forces = model_output["forces"]
state.store_model_extras(model_output)
return momentum_step(state, dt_2)


Expand Down
Loading
Loading