From 7846d27292626fc88f0b1009ecbe74192b3c3fbd Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 30 Jan 2026 21:08:05 -0500 Subject: [PATCH 1/7] fea: test ase to torchsim spin/charge io. Test models don't crash with spins and charges --- tests/test_io.py | 84 +++++++++++++++++++++++++++++++++++ torch_sim/models/interface.py | 48 +++++++++++++++++++- 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/tests/test_io.py b/tests/test_io.py index a2c25ab4e..737517b8c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -5,6 +5,7 @@ import pytest import torch from ase import Atoms +from ase.build import molecule from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -88,6 +89,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: ) +@pytest.mark.parametrize( + ("charge", "spin", "expected_charge", "expected_spin"), + [ + (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin + (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin + (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + ], +) +def test_atoms_to_state_with_charge_spin( + charge: float | None, + spin: float | None, + expected_charge: float, + expected_spin: float, +) -> None: + """Test conversion from ASE Atoms with charge and spin to state tensors.""" + mol = molecule("H2O") + if charge is not None: + mol.info["charge"] = charge + if spin is not None: + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (1,) + assert state.spin.shape == (1,) + assert state.charge[0].item() == expected_charge + assert state.spin[0].item() == expected_spin + + +def test_multiple_atoms_to_state_with_charge_spin() -> None: + """Test conversion from multiple ASE Atoms with different charge/spin values.""" + mol1 = molecule("H2O") + mol1.info["charge"] = 1.0 + mol1.info["spin"] = 1.0 + + mol2 = molecule("CH4") + mol2.info["charge"] = -1.0 + mol2.info["spin"] = 0.0 + + mol3 = molecule("NH3") + mol3.info["charge"] = 0.0 + mol3.info["spin"] = 2.0 + + state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (3,) + assert state.spin.shape == (3,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + assert state.charge[2].item() == 0.0 + assert state.spin[0].item() == 1.0 + assert state.spin[1].item() == 0.0 + assert state.spin[2].item() == 2.0 + + def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) @@ -114,6 +178,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: assert len(atoms[0]) == 32 +def test_state_to_atoms_with_charge_spin() -> None: + """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" + mol = molecule("H2O") + mol.info["charge"] = 1.0 + mol.info["spin"] = 1.0 + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + atoms = ts.io.state_to_atoms(state) + + assert len(atoms) == 1 + assert isinstance(atoms[0], Atoms) + assert "charge" in atoms[0].info + assert "spin" in atoms[0].info + assert atoms[0].info["charge"] == 1 + assert atoms[0].info["spin"] == 1 + + def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) @@ -253,6 +334,9 @@ def test_state_round_trip( # since both use their own isotope masses based on species, # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) + # Check charge/spin round trip + assert torch.allclose(sim_state.charge, round_trip_state.charge) + assert torch.allclose(sim_state.spin, round_trip_state.spin) def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 5c6a243af..dbcd4448f 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -200,7 +200,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 This validator creates small test systems (silicon and iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -229,6 +229,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_cell = sim_state.cell.clone() og_system_idx = sim_state.system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + og_charge = sim_state.charge.clone() + og_spin = sim_state.spin.clone() model_output = model.forward(sim_state) @@ -241,6 +243,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_charge, sim_state.charge): + raise ValueError(f"{og_charge=} != {sim_state.charge=}") + if not torch.allclose(og_spin, sim_state.spin): + raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -300,3 +306,43 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") + + # Test that models can handle non-zero charge and spin + benzene_atoms = molecule("C6H6") + benzene_atoms.info["charge"] = 1.0 + benzene_atoms.info["spin"] = 1.0 + charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + + # Ensure state has charge/spin before testing model + if charged_state.charge is None or charged_state.spin is None: + raise ValueError( + "atoms_to_state did not extract charge/spin. " + "Cannot test model charge/spin handling." + ) + + # Test that model can handle charge/spin without crashing + og_charged_charge = charged_state.charge.clone() + og_charged_spin = charged_state.spin.clone() + try: + charged_output = model.forward(charged_state) + except Exception as e: + raise ValueError( + "Model failed to handle non-zero charge/spin. " + "Models must be able to process states with charge and spin values. " + ) from e + + # Verify model didn't mutate charge/spin + if not torch.allclose(og_charged_charge, charged_state.charge): + raise ValueError( + f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" + ) + if not torch.allclose(og_charged_spin, charged_state.spin): + raise ValueError( + f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + ) + # Verify output shape is still correct + if charged_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect with charge/spin: " + f"{charged_output['energy'].shape=} != (1,)" + ) From dcff78f5f180b538e9f16b6309d3ee8d3395d7d7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 7 Feb 2026 08:35:45 -0500 Subject: [PATCH 2/7] wip --- torch_sim/models/fairchem_legacy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index 5aa966b9d..91b506fa0 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -424,6 +424,8 @@ def forward( # noqa: C901 fixed=fixed[c - n : c].clone(), natoms=n, pbc=state.pbc, + charge=state.charge[i].clone(), + spin=state.spin[i].clone(), ) ) self.data_object = Batch.from_data_list(data_list) From 8bb856bb89417cc89d426acd492716ecc2a82ea1 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 7 Feb 2026 12:45:01 -0500 Subject: [PATCH 3/7] Fairchem v1 should have a planned deprecation date --- torch_sim/models/fairchem_legacy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index 91b506fa0..56ec6e48b 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -416,6 +416,9 @@ def forward( # noqa: C901 for i, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) ): + # NOTE: Legacy FairChem models (v1) do not support charge/spin, + # so we don't pass these fields to the Data object. + # The model will simply ignore charge/spin and treat all systems as neutral. data_list.append( Data( pos=state.positions[c - n : c].clone(), @@ -424,8 +427,6 @@ def forward( # noqa: C901 fixed=fixed[c - n : c].clone(), natoms=n, pbc=state.pbc, - charge=state.charge[i].clone(), - spin=state.spin[i].clone(), ) ) self.data_object = Batch.from_data_list(data_list) From c4929a2c4e08573d5b3461359269b53f70e62d2c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 12:49:20 -0500 Subject: [PATCH 4/7] wip: work out how to have other things as properties --- tests/test_extras.py | 198 +++++++++++++++++++++++++++++++++++++++++++ torch_sim/io.py | 34 ++++++++ torch_sim/state.py | 147 +++++++++++++++++++++++++++++++- 3 files changed, 378 insertions(+), 1 deletion(-) create mode 100644 tests/test_extras.py diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 000000000..b638fe63d --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,198 @@ +import pytest +import torch + +import torch_sim as ts + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class TestExtras: + def test_system_extras_construction(self): + """Extras can be passed at construction time.""" + field = torch.randn(1, 3) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": field}, + ) + assert torch.equal(state.external_E_field, field) + + def test_atom_extras_construction(self): + """Per-atom extras work at construction time.""" + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + assert torch.equal(state.tags, tags) + + def test_getattr_missing_returns_none_and_warns(self, cu_sim_state: ts.SimState): + with pytest.warns( + UserWarning, match="Accessing optional extra 'nonexistent_key'" + ): + val = cu_sim_state.nonexistent_key + assert val is None + + def test_set_extras(self, cu_sim_state: ts.SimState): + field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) + cu_sim_state.set_extras("E", field, scope="per-system") + assert torch.equal(cu_sim_state.E, field) + + def test_set_extras_bad_shape(self, cu_sim_state: ts.SimState): + bad = torch.randn(cu_sim_state.n_systems + 5, 3) + with pytest.raises(ValueError, match="leading dim must be n_systems"): + cu_sim_state.set_extras("bad", bad, scope="per-system") + + def test_clone_preserves_extras(self, cu_sim_state: ts.SimState): + field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) + cu_sim_state.set_extras("E", field, scope="per-system") + cloned = cu_sim_state.clone() + assert torch.equal(cloned.E, field) + # verify independence + cloned.system_extras["E"].zero_() + assert not torch.equal(cu_sim_state.E, cloned.E) + + def test_split_preserves_extras(self, si_double_sim_state: ts.SimState): + field = torch.randn( + si_double_sim_state.n_systems, 3, device=si_double_sim_state.device + ) + si_double_sim_state.set_extras("H", field, scope="per-system") + splits = si_double_sim_state.split() + for i, s in enumerate(splits): + assert torch.equal(s.H, field[i : i + 1]) + + def test_getitem_preserves_extras(self, si_double_sim_state: ts.SimState): + field = torch.randn( + si_double_sim_state.n_systems, 3, device=si_double_sim_state.device + ) + si_double_sim_state.set_extras("E", field, scope="per-system") + sub = si_double_sim_state[[0]] + assert torch.equal(sub.E, field[0:1]) + + def test_concatenate_preserves_extras(self, cu_sim_state: ts.SimState): + s1 = cu_sim_state.clone() + s2 = cu_sim_state.clone() + f1 = torch.randn(s1.n_systems, 3, device=s1.device) + f2 = torch.randn(s2.n_systems, 3, device=s2.device) + s1.set_extras("E", f1, scope="per-system") + s2.set_extras("E", f2, scope="per-system") + merged = ts.concatenate_states([s1, s2]) + assert torch.equal(merged.E, torch.cat([f1, f2], dim=0)) + + def test_to_device_moves_extras(self, cu_sim_state: ts.SimState): + field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) + cu_sim_state.set_extras("E", field, scope="per-system") + moved = cu_sim_state.to(device=cu_sim_state.device) + assert moved.E.device == cu_sim_state.device + + def test_pop_preserves_extras(self, si_double_sim_state: ts.SimState): + field = torch.randn( + si_double_sim_state.n_systems, 3, device=si_double_sim_state.device + ) + si_double_sim_state.set_extras("E", field, scope="per-system") + popped = si_double_sim_state.pop(0) + assert popped[0].E.shape[0] == 1 + + def test_has_extras(self, cu_sim_state: ts.SimState): + assert not cu_sim_state.has_extras("E") + cu_sim_state.set_extras( + "E", + torch.zeros(cu_sim_state.n_systems, 3, device=cu_sim_state.device), + scope="per-system", + ) + assert cu_sim_state.has_extras("E") + + def test_post_init_validation_rejects_bad_shape(self): + with pytest.raises(ValueError, match="leading dim must be n_systems"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"bad": torch.randn(5, 3)}, + ) + + def test_from_state_preserves_extras(self, cu_sim_state: ts.SimState): + field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) + cu_sim_state.set_extras("E", field, scope="per-system") + new = ts.SimState.from_state(cu_sim_state) + assert torch.equal(new.E, field) + + def test_extras_cannot_shadow_declared_fields(self, cu_sim_state: ts.SimState): + # set_extras should raise if attempting to shadow + with pytest.raises(ValueError, match="shadows an existing attribute"): + cu_sim_state.set_extras( + "cell", torch.zeros(cu_sim_state.n_systems, 3), scope="per-system" + ) + + def test_construction_extras_cannot_shadow(self): + # Post-init validation should also catch shadowing during construction + with pytest.raises(ValueError, match="shadows an existing attribute"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"cell": torch.zeros(1, 3)}, + ) + + def test_optional_extras_return_none_and_warn(self, cu_sim_state: ts.SimState): + with pytest.warns( + UserWarning, match="Accessing optional extra 'external_E_field'" + ): + val = cu_sim_state.external_E_field + assert val is None + + def test_arbitrary_extras_return_none_and_warn(self, cu_sim_state: ts.SimState): + # Even unknown fields now return None + warning for extensibility + with pytest.warns(UserWarning, match="Accessing optional extra 'random_field'"): + val = cu_sim_state.random_field + assert val is None + + +def test_system_extras_atoms_roundtrip(): + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])}, + ) + atoms_list = state.to_atoms() + assert "external_E_field" in atoms_list[0].info + restored = ts.io.atoms_to_state( + atoms_list, + system_extras_keys=["external_E_field"], + ) + assert torch.allclose(restored.external_E_field, state.external_E_field) + + +def test_atom_extras_atoms_roundtrip(): + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + atoms_list = state.to_atoms() + assert "tags" in atoms_list[0].arrays + restored = ts.io.atoms_to_state( + atoms_list, + atom_extras_keys=["tags"], + ) + assert torch.allclose(restored.tags, state.tags) diff --git a/torch_sim/io.py b/torch_sim/io.py index 835446ec0..ad0af3b1e 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -86,6 +86,14 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: if spin is not None: atoms.info["spin"] = int(spin[sys_idx].item()) + # Write system extras to atoms.info + for key, val in state.system_extras.items(): + atoms.info[key] = val[sys_idx].detach().cpu().numpy() + + # Write atom extras to atoms.arrays + for key, val in state.atom_extras.items(): + atoms.arrays[key] = val[mask].detach().cpu().numpy() + atoms_list.append(atoms) return atoms_list @@ -218,6 +226,8 @@ def atoms_to_state( atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None, + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -226,6 +236,10 @@ def atoms_to_state( device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensors (typically torch.float32 or torch.float64) + system_extras_keys (list[str]): Optional list of keys to read from atoms.info + into _system_extras + atom_extras_keys (list[str]): Optional list of keys to read from atoms.arrays + into _atom_extras Returns: SimState: TorchSim SimState object. @@ -279,6 +293,24 @@ def atoms_to_state( [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device ) + _system_extras: dict[str, torch.Tensor] = {} + if system_extras_keys: + for key in system_extras_keys: + vals = [at.info.get(key) for at in atoms_list] + if all(v is not None for v in vals): + _system_extras[key] = torch.tensor( + np.stack(vals), dtype=dtype, device=device + ) + + _atom_extras: dict[str, torch.Tensor] = {} + if atom_extras_keys: + for key in atom_extras_keys: + arrays = [at.arrays.get(key) for at in atoms_list] + if all(a is not None for a in arrays): + _atom_extras[key] = torch.tensor( + np.concatenate(arrays), dtype=dtype, device=device + ) + return ts.SimState( positions=positions, masses=masses, @@ -288,6 +320,8 @@ def atoms_to_state( system_idx=system_idx, charge=charge, spin=spin, + _system_extras=_system_extras, + _atom_extras=_atom_extras, ) diff --git a/torch_sim/state.py b/torch_sim/state.py index bd0dd1f2a..421376790 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -7,6 +7,7 @@ import copy import importlib import typing +import warnings from collections import defaultdict from collections.abc import Generator, Sequence from dataclasses import dataclass, field @@ -92,6 +93,8 @@ class SimState: spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor | None = field(default=None) _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) + _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) if TYPE_CHECKING: @@ -183,6 +186,29 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + # Validate extras shapes and prevent shadowing + all_attrs = self._get_all_attributes() + for key, val in self._system_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"System extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"System extra '{key}' must be a torch.Tensor") + if val.shape[0] != n_systems: + raise ValueError( + f"System extra '{key}' leading dim must be " + f"n_systems={n_systems}, got {val.shape[0]}" + ) + for key, val in self._atom_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"Atom extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"Atom extra '{key}' must be a torch.Tensor") + if val.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extra '{key}' leading dim must be " + f"n_atoms={self.n_atoms}, got {val.shape[0]}" + ) + @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" @@ -190,8 +216,76 @@ def _get_all_attributes(cls) -> set[str]: cls._atom_attributes | cls._system_attributes | cls._global_attributes - | {"_constraints"} + | {"_constraints", "_system_extras", "_atom_extras"} + ) + + def __getattr__(self, name: str) -> Any: + """Allow attribute-style access to extras dict entries.""" + # Guard: don't look up private attrs in extras (avoids recursion during init) + if name.startswith("_"): + raise AttributeError(name) + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + return extras[name] + + # Any public attribute that's not found is treated as a missing extra + # to ensure that new models can define their own optional extras without + # modifying the core state definition. + warnings.warn( + f"Accessing optional extra '{name}' which is not set on this state. " + "Returning None.", + category=UserWarning, + stacklevel=2, ) + return None + + @property + def system_extras(self) -> dict[str, torch.Tensor]: + """Get the system extras.""" + return self._system_extras + + @property + def atom_extras(self) -> dict[str, torch.Tensor]: + """Get the atom extras.""" + return self._atom_extras + + def set_extras( + self, + key: str, + value: torch.Tensor, + scope: Literal["per-system", "per-atom"], + ) -> None: + """Set an extras tensor with explicit scope and shape validation.""" + if key in self._get_all_attributes() or hasattr(type(self), key): + raise ValueError( + f"Cannot set extra '{key}' because it shadows an existing attribute" + ) + if not isinstance(value, torch.Tensor): + raise TypeError(f"Extras value must be a torch.Tensor, got {type(value)}") + if scope == "per-system": + if value.shape[0] != self.n_systems: + raise ValueError( + f"System extras {key} leading dim must be " + f"n_systems={self.n_systems}, got {value.shape[0]}" + ) + self._system_extras[key] = value + elif scope == "per-atom": + if value.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extras {key} leading dim must be " + f"n_atoms={self.n_atoms}, got {value.shape[0]}" + ) + self._atom_extras[key] = value + else: + raise ValueError(f"scope must be 'per-system' or 'per-atom', got {scope!r}") + + def has_extras(self, key: str) -> bool: + """Check if an extras key exists.""" + return key in self._system_extras or key in self._atom_extras @property def wrap_positions(self) -> torch.Tensor: @@ -731,11 +825,25 @@ def _state_to_device[T: SimState]( if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(device=device) for k, v in attrs[extras_key].items() + } + if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + # Update floating point extras to new dtype + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(dtype=dtype) if v.is_floating_point() else v + for k, v in attrs[extras_key].items() + } return type(state)(**attrs) @@ -823,6 +931,13 @@ def _filter_attrs_by_index( val[system_indices] if isinstance(val, torch.Tensor) else val ) + filtered_attrs["_system_extras"] = { + key: val[system_indices] for key, val in state.system_extras.items() + } + filtered_attrs["_atom_extras"] = { + key: val[atom_indices] for key, val in state.atom_extras.items() + } + return filtered_attrs @@ -855,6 +970,14 @@ def _split_state[T: SimState](state: T) -> list[T]: global_attrs = dict(get_attrs_for_scope(state, "global")) + split_system_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._system_extras.items(): # noqa: SLF001 + split_system_extras[key] = list(torch.split(val, 1, dim=0)) + + split_atom_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._atom_extras.items(): # noqa: SLF001 + split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) @@ -881,6 +1004,12 @@ def _split_state[T: SimState](state: T) -> list[T]: **per_system_dict, # Add the global attributes **global_attrs, + "_system_extras": { + key: split_system_extras[key][sys_idx] for key in split_system_extras + }, + "_atom_extras": { + key: split_atom_extras[key][sys_idx] for key in split_atom_extras + }, } atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1]) @@ -1025,6 +1154,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) + atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] @@ -1046,6 +1177,12 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) + # Collect extras + for key, val in state.system_extras.items(): + system_extras_tensors[key].append(val) + for key, val in state.atom_extras.items(): + atom_extras_tensors[key].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -1116,6 +1253,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Concatenate extras + concatenated["_system_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items() + } + concatenated["_atom_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in atom_extras_tensors.items() + } + # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states] From 005ecfe865e6bdf72ab03d379b95b4065c1f0411 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 13:30:48 -0500 Subject: [PATCH 5/7] maint: cleaner handling of numpy dep for py313 --- .github/workflows/test.yml | 7 +------ pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f88a7c86a..b9623d7d2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,12 +39,7 @@ jobs: - name: Install torch_sim run: | uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system - # always use numpy>=2 with Python 3.13 - if [ "${{ matrix.version.python }}" = "3.13" ]; then - uv pip install -e ".[test]" "numpy>=2" --resolution=${{ matrix.version.resolution }} --system - else - uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system - fi + uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system - name: Run core tests run: | diff --git a/pyproject.toml b/pyproject.toml index 1280e577a..56d57b235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ requires-python = ">=3.12" dependencies = [ "h5py>=3.12.1", "nvalchemi-toolkit-ops>=0.2.0", - "numpy>=1.26,<3", + "numpy>=1.26,<3; python_version < '3.13'", + "numpy>=2,<3; python_version >= '3.13'", "tables>=3.10.2,<3.11", "torch>=2", "tqdm>=4.67", From b307d26afbf14c279f7ce1b31d08eb01986bf8dc Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 15:46:29 -0500 Subject: [PATCH 6/7] fix getattr issue. this works but would be a typing nightmare --- tests/test_extras.py | 22 +++------------------- torch_sim/state.py | 14 ++++---------- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/tests/test_extras.py b/tests/test_extras.py index b638fe63d..88a6ebe39 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -35,12 +35,9 @@ def test_atom_extras_construction(self): ) assert torch.equal(state.tags, tags) - def test_getattr_missing_returns_none_and_warns(self, cu_sim_state: ts.SimState): - with pytest.warns( - UserWarning, match="Accessing optional extra 'nonexistent_key'" - ): - val = cu_sim_state.nonexistent_key - assert val is None + def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState): + with pytest.raises(AttributeError, match="nonexistent_key"): + _ = cu_sim_state.nonexistent_key def test_set_extras(self, cu_sim_state: ts.SimState): field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) @@ -147,19 +144,6 @@ def test_construction_extras_cannot_shadow(self): _system_extras={"cell": torch.zeros(1, 3)}, ) - def test_optional_extras_return_none_and_warn(self, cu_sim_state: ts.SimState): - with pytest.warns( - UserWarning, match="Accessing optional extra 'external_E_field'" - ): - val = cu_sim_state.external_E_field - assert val is None - - def test_arbitrary_extras_return_none_and_warn(self, cu_sim_state: ts.SimState): - # Even unknown fields now return None + warning for extensibility - with pytest.warns(UserWarning, match="Accessing optional extra 'random_field'"): - val = cu_sim_state.random_field - assert val is None - def test_system_extras_atoms_roundtrip(): state = ts.SimState( diff --git a/torch_sim/state.py b/torch_sim/state.py index 421376790..b24e69a1b 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -7,7 +7,6 @@ import copy import importlib import typing -import warnings from collections import defaultdict from collections.abc import Generator, Sequence from dataclasses import dataclass, field @@ -232,16 +231,11 @@ def __getattr__(self, name: str) -> Any: if name in extras: return extras[name] - # Any public attribute that's not found is treated as a missing extra - # to ensure that new models can define their own optional extras without - # modifying the core state definition. - warnings.warn( - f"Accessing optional extra '{name}' which is not set on this state. " - "Returning None.", - category=UserWarning, - stacklevel=2, + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. + raise AttributeError( + f"'{type(self).__name__}' has no attribute or extra '{name}'" ) - return None @property def system_extras(self) -> dict[str, torch.Tensor]: From 6df30681e3018c9cf7e05aa9842903c5ac1c2863 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 21:23:51 -0500 Subject: [PATCH 7/7] fea: generalize ts interface to predict anything. --- tests/test_extras.py | 52 ++++++++++++++++++++++++ torch_sim/integrators/md.py | 1 + torch_sim/integrators/npt.py | 19 +++++++-- torch_sim/integrators/nve.py | 5 ++- torch_sim/integrators/nvt.py | 13 ++++-- torch_sim/models/mace.py | 7 ++++ torch_sim/monte_carlo.py | 5 ++- torch_sim/optimizers/bfgs.py | 6 ++- torch_sim/optimizers/fire.py | 7 +++- torch_sim/optimizers/gradient_descent.py | 3 ++ torch_sim/optimizers/lbfgs.py | 6 ++- torch_sim/runners.py | 21 +++++++++- torch_sim/state.py | 34 ++++++++++++++++ 13 files changed, 166 insertions(+), 13 deletions(-) diff --git a/tests/test_extras.py b/tests/test_extras.py index 88a6ebe39..e40738463 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -144,6 +144,58 @@ def test_construction_extras_cannot_shadow(self): _system_extras={"cell": torch.zeros(1, 3)}, ) + # store_model_extras + def test_store_model_extras_canonical_keys_not_stored( + self, si_double_sim_state: ts.SimState + ): + """Canonical keys (energy, forces, stress) must not land in extras.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "forces": torch.randn(state.n_atoms, 3), + "stress": torch.randn(state.n_systems, 3, 3), + } + ) + assert not state._system_extras # noqa: SLF001 + assert not state._atom_extras # noqa: SLF001 + + def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_systems go into system_extras.""" + state = si_double_sim_state.clone() + dipole = torch.randn(state.n_systems, 3) + state.store_model_extras( + {"energy": torch.randn(state.n_systems), "dipole": dipole} + ) + assert torch.equal(state.dipole, dipole) + + def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_atoms go into atom_extras.""" + state = si_double_sim_state.clone() + charges = torch.randn(state.n_atoms) + density = torch.randn(state.n_atoms, 8) + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "charges": charges, + "density_coefficients": density, + } + ) + assert torch.equal(state.charges, charges) + assert state.density_coefficients.shape == (state.n_atoms, 8) + + def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState): + """0-d tensors and non-Tensor values are silently ignored.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "scalar": torch.tensor(3.14), + "string": "not a tensor", + } + ) + assert not state.has_extras("scalar") + assert not state.has_extras("string") + def test_system_extras_atoms_roundtrip(): state = ts.SimState( diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 3185b6cdf..ac5f31200 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -223,6 +223,7 @@ def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterfac state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_2) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index a709c4c2a..e2ea0576a 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -613,7 +613,7 @@ def npt_langevin_init( ) # Create the initial state - return NPTLangevinState.from_state( + npt_state = NPTLangevinState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -627,6 +627,8 @@ def npt_langevin_init( cell_masses=cell_masses, cell_alpha=cell_alpha, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1063/1.4901303") @@ -688,6 +690,7 @@ def npt_langevin_step( model_output = model(state) state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Store initial values for integration forces = state.forces @@ -723,6 +726,7 @@ def npt_langevin_step( state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Compute updated pressure force F_p_n_new = _compute_cell_force( @@ -1264,6 +1268,7 @@ def _npt_nose_hoover_inner_step( state.set_constrained_momenta(momenta) state.forces = model_output["forces"] state.energy = model_output["energy"] + state.store_model_extras(model_output) state.cell_position = cell_position state.cell_momentum = cell_momentum state.cell_mass = cell_mass @@ -1417,7 +1422,7 @@ def npt_nose_hoover_init( ) # Create initial state - return NPTNoseHooverState.from_state( + npt_state = NPTNoseHooverState.from_state( state, momenta=momenta, energy=energy, @@ -1432,6 +1437,8 @@ def npt_nose_hoover_init( barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1080/00268979600100761") @@ -2028,6 +2035,7 @@ def npt_crescale_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2098,6 +2106,7 @@ def npt_crescale_independent_lengths_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2169,6 +2178,7 @@ def npt_crescale_average_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2241,6 +2251,7 @@ def npt_crescale_isotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2314,7 +2325,7 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState.from_state( + npt_state = NPTCRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -2323,3 +2334,5 @@ def npt_crescale_init( tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) + npt_state.store_model_extras(model_output) + return npt_state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 4880cfac6..938e8d60c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,12 +57,14 @@ def nve_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state def nve_step( @@ -99,5 +101,6 @@ def nve_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 957ad89f9..71b6f4604 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -119,12 +119,14 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state @dcite("10.1098/rspa.2016.0138") @@ -188,6 +190,7 @@ def nvt_langevin_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) @@ -312,7 +315,7 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState.from_state( + nh_state = NVTNoseHooverState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -321,6 +324,8 @@ def nvt_nose_hoover_init( chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, ) + nh_state.store_model_extras(model_output) + return nh_state @dcite("10.1080/00268979600100761") @@ -595,12 +600,14 @@ def nvt_vrescale_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return NVTVRescaleState.from_state( + vr_state = NVTVRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + vr_state.store_model_extras(model_output) + return vr_state @dcite("10.1063/1.2408420") diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 678405ef8..790cb2633 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -368,6 +368,13 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Propagate additional model outputs (e.g. dipole, charges, etc.) + for key, val in out.items(): + if key not in ("energy", "forces", "stress") and isinstance( + val, torch.Tensor + ): + results[key] = val.detach() + return results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index fc88d7f59..13312d11b 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -205,7 +205,7 @@ def swap_mc_init( """ model_output = model(state) - return SwapMCState( + mc_state = SwapMCState( positions=state.positions, masses=state.masses, cell=state.cell, @@ -216,6 +216,8 @@ def swap_mc_init( last_permutation=torch.arange(state.n_atoms, device=state.device), _constraints=state.constraints, ) + mc_state.store_model_extras(model_output) + return mc_state def swap_mc_step( @@ -275,5 +277,6 @@ def swap_mc_step( state.energy = torch.where(accepted, energies_new, energies_old) state.last_permutation = permutation[reverse_rejected_swaps].clone() + state.store_model_extras(model_output) return state diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index 458f90979..981748e56 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -206,6 +206,7 @@ def bfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms @@ -237,7 +238,9 @@ def bfgs_init( "_constraints": state.constraints, # preserve constraints } - return BFGSState(**common_args) + bfgs_state = BFGSState(**common_args) + bfgs_state.store_model_extras(model_output) + return bfgs_state def bfgs_step( # noqa: C901, PLR0915 @@ -541,6 +544,7 @@ def bfgs_step( # noqa: C901, PLR0915 state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] + state.store_model_extras(model_output) # Update cell forces for next step # Update cell forces for cell state: [S, 3, 3] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 49ab9f2cf..e3665c961 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -99,9 +99,12 @@ def fire_init( cell_state.cell_forces.shape, torch.nan, **tensor_args ) + cell_state.store_model_extras(model_output) return cell_state # Create regular FireState without cell optimization - return FireState.from_state(state, **fire_attrs) + fire_state = FireState.from_state(state, **fire_attrs) + fire_state.store_model_extras(model_output) + return fire_state def fire_step( @@ -214,6 +217,7 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): @@ -463,6 +467,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 23a51a0ed..80d3f2c42 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -57,6 +57,8 @@ def gradient_descent_init( "stress": stress, } + state.store_model_extras(model_output) + if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) optim_attrs["reference_cell"] = state.cell.clone() @@ -115,6 +117,7 @@ def gradient_descent_step( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellOptimState): diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 53a7e0bdc..fb648309d 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -265,9 +265,12 @@ def lbfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state - return LBFGSState(**common_args) + lbfgs_state = LBFGSState(**common_args) + lbfgs_state.store_model_extras(model_output) + return lbfgs_state def lbfgs_step( # noqa: PLR0915, C901 @@ -529,6 +532,7 @@ def lbfgs_step( # noqa: PLR0915, C901 new_forces = model_output["forces"] # [N, 3] new_energy = model_output["energy"] # [S] new_stress = model_output.get("stress") # [S, 3, 3] or None + state.store_model_extras(model_output) # Update cell forces for next step: [S, 3, 3] if is_cell_state: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8f2772894..4c3567d40 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState -from torch_sim.state import SimState +from torch_sim.state import _CANONICAL_MODEL_KEYS, SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem @@ -719,7 +719,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 return state # type: ignore[return-value] -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -821,8 +821,25 @@ class StaticState(SimState): else torch.full_like(sub_state.cell, fill_value=float("nan")) ), ) + static_state.store_model_extras(model_outputs) props = trajectory_reporter.report(static_state, 0, model=model) + + # Merge extra model outputs into per-system property dicts + # TODO: this should be cleaner? + extra_keys = {k for k in model_outputs if k not in _CANONICAL_MODEL_KEYS} + if extra_keys: + for sys_idx, sys_props in enumerate(props): + for key in extra_keys: + val = model_outputs[key] + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + if val.shape[0] == static_state.n_atoms: + mask = static_state.system_idx == sys_idx + sys_props[key] = val[mask] + elif val.shape[0] == static_state.n_systems: + sys_props[key] = val[sys_idx : sys_idx + 1] + all_props.extend(props) if tqdm_pbar: diff --git a/torch_sim/state.py b/torch_sim/state.py index b24e69a1b..5c1148b29 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -26,6 +26,10 @@ from torch_sim.constraints import Constraint, merge_constraints, validate_constraints +# Canonical model output keys that are handled explicitly by integrators/runners +_CANONICAL_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + + @dataclass class SimState: """State representation for atomistic systems with batched operations support. @@ -281,6 +285,36 @@ def has_extras(self, key: str) -> bool: """Check if an extras key exists.""" return key in self._system_extras or key in self._atom_extras + def store_model_extras(self, model_output: dict[str, torch.Tensor]) -> None: + """Store non-canonical model outputs into state extras (in-place). + + Any key in *model_output* that is not in ``{"energy", "forces", "stress"}`` + is classified by its leading dimension: + + * ``n_atoms`` → stored in ``_atom_extras`` + * ``n_systems`` → stored in ``_system_extras`` + * otherwise → skipped (ambiguity or scalar) + + When ``n_atoms == n_systems`` (single-atom system), the tensor is stored as + per-atom by convention. + + Args: + model_output: Full dict returned by ``model.forward()``. + """ + n_atoms = self.n_atoms + n_systems = self.n_systems + + for key, val in model_output.items(): + if key in _CANONICAL_MODEL_KEYS: + continue + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + leading = val.shape[0] + if leading == n_atoms: + self._atom_extras[key] = val + elif leading == n_systems: + self._system_extras[key] = val + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True,