Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scripts/5_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@

# Print results
print("\nElastic tensor (GPa):")
elastic_tensor_np = elastic_tensor.cpu().numpy()
elastic_tensor_np = elastic_tensor.detach().cpu().numpy()
for row in elastic_tensor_np:
print(" " + " ".join(f"{val:10.4f}" for val in row))

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/7_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
time = time_steps * correlation_dt * timestep * 1000 # Convert to fs

if vacf_calc.vacf is not None:
vacf_data = vacf_calc.vacf.cpu().numpy()
vacf_data = vacf_calc.vacf.detach().cpu().numpy()
print("\nVACF calculation complete:")
print(f" Number of windows averaged: {vacf_calc._window_count}") # noqa: SLF001
print(f" VACF at t=0: {vacf_data[0]:.4f}")
Expand Down
30 changes: 30 additions & 0 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,33 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:
# Verify outputs are finite
assert torch.isfinite(result["energy"]).all()
assert torch.isfinite(result["forces"]).all()


# TODO: consider folding a check like this into `validate_model_outputs`;
# the open question is how to do so without creating a circular dependency.
@pytest.mark.skipif(
    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
)
def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None:
    """Run exactly one FIRE optimization step with FairChemModel.

    Ensures the model composes correctly with optimizers — in particular
    that the computational graph is handled properly (e.g. no missing
    .detach() calls when tensors leave the autograd graph).
    """
    fairchem_model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
    sim_state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE)

    # Set up the FIRE optimizer and record the starting configuration
    fire_state = ts.fire_init(sim_state, fairchem_model)
    start_positions = fire_state.positions.clone()
    _start_energy = fire_state.energy.item()

    # Take a single optimization step
    fire_state = ts.fire_step(fire_state, fairchem_model)

    # The step must have moved the atoms
    assert not torch.allclose(fire_state.positions, start_positions)
    # Energy must remain accessible and finite after the step
    assert torch.isfinite(fire_state.energy).all()
    assert isinstance(fire_state.energy.item(), float)
2 changes: 1 addition & 1 deletion tests/models/test_orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator:
def test_cell_to_cellpar(ti_sim_state: SimState) -> None:
assert np.allclose(
ase_c2p(ti_sim_state.row_vector_cell.squeeze()),
cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).cpu().numpy(),
cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).detach().cpu().numpy(),
)
assert np.allclose(
ase_c2p(ti_sim_state.row_vector_cell.squeeze(), radians=True),
Expand Down
32 changes: 20 additions & 12 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def test_expm_frechet(self):
A = torch.from_numpy(A).to(device=device)
E = torch.from_numpy(E).to(device=device)
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14)
assert_allclose(expected_expm, observed_expm.detach().cpu().numpy(), atol=1e-14)
assert_allclose(
expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-14
)

def test_small_norm_expm_frechet(self):
"""Test matrices with small norms."""
Expand All @@ -41,8 +43,10 @@ def test_small_norm_expm_frechet(self):
A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14)
assert_allclose(expected_expm, observed_expm.detach().cpu().numpy(), atol=1e-14)
assert_allclose(
expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-14
)

def test_fuzz(self):
"""Test with a variety of random 3x3 inputs to ensure robustness."""
Expand All @@ -61,8 +65,12 @@ def test_fuzz(self):
A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=5e-8)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-7)
assert_allclose(
expected_expm, observed_expm.detach().cpu().numpy(), atol=5e-8
)
assert_allclose(
expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-7
)

def test_problematic_matrix(self):
"""Test a specific matrix that previously uncovered a bug."""
Expand All @@ -79,8 +87,8 @@ def test_problematic_matrix(self):
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A.unsqueeze(0), E.unsqueeze(0), method="blockEnlarge"
)
assert_allclose(sps_expm, blockEnlarge_expm[0].cpu().numpy())
assert_allclose(sps_frechet, blockEnlarge_frechet[0].cpu().numpy())
assert_allclose(sps_expm, blockEnlarge_expm[0].detach().cpu().numpy())
assert_allclose(sps_frechet, blockEnlarge_frechet[0].detach().cpu().numpy())

def test_medium_matrix(self):
"""Test with a medium-sized matrix to compare performance between methods."""
Expand All @@ -96,8 +104,8 @@ def test_medium_matrix(self):
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A.unsqueeze(0), E.unsqueeze(0), method="blockEnlarge"
)
assert_allclose(sps_expm, blockEnlarge_expm[0].cpu().numpy())
assert_allclose(sps_frechet, blockEnlarge_frechet[0].cpu().numpy())
assert_allclose(sps_expm, blockEnlarge_expm[0].detach().cpu().numpy())
assert_allclose(sps_frechet, blockEnlarge_frechet[0].detach().cpu().numpy())


class TestExpmFrechetTorch:
Expand Down Expand Up @@ -341,7 +349,7 @@ def test_random_float(self):
n = 3
M = torch.randn(n, n, dtype=DTYPE, device=device)
M_logm = fm.matrix_log_33(M)
scipy_logm = scipy.linalg.logm(M.cpu().numpy())
scipy_logm = scipy.linalg.logm(M.detach().cpu().numpy())
torch.testing.assert_close(
M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device)
)
Expand All @@ -360,7 +368,7 @@ def test_nearly_degenerate(self):
device=device,
)
M_logm = fm._matrix_log_33(M)
scipy_logm = scipy.linalg.logm(M.cpu().numpy())
scipy_logm = scipy.linalg.logm(M.detach().cpu().numpy())
torch.testing.assert_close(
M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device)
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_unbatched_total_flux(

# Heat flux parts should cancel out
expected = torch.zeros(3, device=flux.device)
assert_allclose(flux.cpu().numpy(), expected.cpu().numpy())
assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy())

def test_unbatched_virial_only(
self, mock_simple_system: dict[str, torch.Tensor]
Expand All @@ -73,7 +73,7 @@ def test_unbatched_virial_only(
)

expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device)
assert_allclose(virial.cpu().numpy(), expected.cpu().numpy())
assert_allclose(virial.detach().cpu().numpy(), expected.detach().cpu().numpy())

def test_batched_calculation(self) -> None:
"""Test heat flux calculation with batched data."""
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_batched_calculation(self) -> None:

# Each batch should cancel heat flux parts
expected = torch.zeros((2, 3), device=DEVICE)
assert_allclose(flux.cpu().numpy(), expected.cpu().numpy())
assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy())

def test_centroid_stress(self) -> None:
"""Test heat flux with centroid stress formulation."""
Expand All @@ -130,7 +130,7 @@ def test_centroid_stress(self) -> None:

# Heatflux should be [-1,-1,-1]
expected = torch.full((3,), -1.0, device=DEVICE)
assert_allclose(flux.cpu().numpy(), expected.cpu().numpy())
assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy())

def test_momenta_input(self) -> None:
"""Test heat flux calculation using momenta instead."""
Expand All @@ -149,7 +149,7 @@ def test_momenta_input(self) -> None:

# Heat flux terms should cancel out
expected = torch.zeros(3, device=DEVICE)
assert_allclose(flux.cpu().numpy(), expected.cpu().numpy())
assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy())


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def test_state_set_cell(ti_sim_state: SimState) -> None:
)
ase_atoms = ti_sim_state.to_atoms()[0]
ti_sim_state.set_cell(new_cell, scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.detach().cpu().numpy(), scale_atoms=True)
assert torch.allclose(
ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions)
)
Expand All @@ -715,7 +715,7 @@ def test_state_set_cell(ti_sim_state: SimState) -> None:
new_cell = M @ ti_sim_state.cell
ase_atoms = ti_sim_state.to_atoms()[0]
ti_sim_state.set_cell(new_cell, scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.detach().cpu().numpy(), scale_atoms=True)
assert torch.allclose(
ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/workflows/test_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def test_get_subcells_with_composition_restrictions(

# Check that all candidates match the requested compositions
for ids, _, _ in candidates:
subcell_species = [species[int(i)] for i in ids.cpu().numpy()]
subcell_species = [species[int(i)] for i in ids.detach().cpu().numpy()]
comp = Composition("".join(subcell_species)).reduced_formula
assert comp in restrict_to_compositions, (
f"Found composition {comp} not in {restrict_to_compositions}"
Expand Down
8 changes: 7 additions & 1 deletion torch_sim/integrators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@

import torch_sim as ts

from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet
from .md import (
MDState,
calculate_momenta,
momentum_step,
position_step,
velocity_verlet_step,
)
from .npt import (
NPTLangevinState,
NPTNoseHooverState,
Expand Down
57 changes: 36 additions & 21 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Core molecular dynamics state and operations."""

import warnings
from collections.abc import Callable
from dataclasses import dataclass

Expand Down Expand Up @@ -120,7 +121,8 @@ def calculate_momenta(
if seed is not None:
generator.manual_seed(seed)

if isinstance(kT, torch.Tensor) and len(kT.shape) > 0:
kT = torch.as_tensor(kT, device=device, dtype=dtype)
if kT.ndim > 0:
# kT is a tensor with shape (n_systems,)
kT = kT[system_idx]

Expand Down Expand Up @@ -166,6 +168,7 @@ def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T:
MDState: Updated state with new momenta after force application

"""
dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype)
new_momenta = state.momenta + state.forces * dt
state.set_constrained_momenta(new_momenta)
return state
Expand All @@ -186,12 +189,15 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T:
MDState: Updated state with new positions after propagation

"""
dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype)
new_positions = state.positions + state.velocities * dt
state.set_constrained_positions(new_positions)
return state


def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterface) -> T:
def velocity_verlet_step[T: MDState](
state: T, dt: float | torch.Tensor, model: ModelInterface
) -> T:
"""Perform one complete velocity Verlet integration step.

This function implements the velocity Verlet algorithm, which provides
Expand All @@ -215,6 +221,7 @@ def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterfac
- Conserves energy in the absence of numerical errors
- Handles periodic boundary conditions if enabled in state
"""
dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype)
dt_2 = dt / 2
state = momentum_step(state, dt_2)
state = position_step(state, dt)
Expand All @@ -226,6 +233,18 @@ def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterfac
return momentum_step(state, dt_2)


def velocity_verlet[T: MDState](
    state: T, dt: float | torch.Tensor, model: ModelInterface
) -> T:
    """Deprecated alias for velocity_verlet_step.

    Emits a DeprecationWarning (stacklevel=2 so the warning points at the
    caller's line, not this shim) and forwards all arguments unchanged.

    Args:
        state: Current MD state to integrate.
        dt: Integration timestep (scalar float or tensor).
        model: Model used to recompute forces after the position update.

    Returns:
        The updated state returned by velocity_verlet_step.
    """
    warnings.warn(
        "velocity_verlet is deprecated. Use velocity_verlet_step instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return velocity_verlet_step(state=state, dt=dt, model=model)


@dataclass
class NoseHooverChain:
"""State information for a Nose-Hoover chain thermostat.
Expand Down Expand Up @@ -317,11 +336,11 @@ class NoseHooverChainFns:
@dcite("10.2183/pjab.69.161")
@dcite("10.1016/0375-9601(90)90092-3")
def construct_nose_hoover_chain( # noqa: C901 PLR0915
dt: torch.Tensor,
dt: float | torch.Tensor,
chain_length: int,
chain_steps: int,
sy_steps: int,
tau: torch.Tensor,
tau: float | torch.Tensor,
) -> NoseHooverChainFns:
"""Creates functions to simulate a Nose-Hoover Chain thermostat.

Expand Down Expand Up @@ -378,16 +397,14 @@ def init_fn(
p_xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device)

# Broadcast tau to match n_systems
if isinstance(tau, torch.Tensor):
tau_batched = tau.expand(n_systems) if tau.dim() == 0 else tau
else:
tau_batched = torch.full((n_systems,), tau, dtype=dtype, device=device)
tau_batched = torch.as_tensor(tau, device=device, dtype=dtype)
if tau_batched.ndim == 0:
tau_batched = tau_batched.expand(n_systems)

# Ensure kT has proper batch dimension
if isinstance(kT, torch.Tensor):
kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT
else:
kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device)
kT_batched = torch.as_tensor(kT, device=device, dtype=dtype)
if kT_batched.ndim == 0:
kT_batched = kT_batched.expand(n_systems)

Q = (
kT_batched.unsqueeze(-1)
Expand Down Expand Up @@ -433,10 +450,9 @@ def substep_fn(
M = chain_length - 1

# Ensure kT has proper batch dimension
if isinstance(kT, torch.Tensor):
kT_batched = kT.expand(KE.shape[0]) if kT.dim() == 0 else kT
else:
kT_batched = torch.full_like(KE, kT)
kT_batched = torch.as_tensor(kT, device=KE.device, dtype=KE.dtype)
if kT_batched.ndim == 0:
kT_batched = kT_batched.expand(KE.shape[0])

# Update chain momenta backwards
if M > 0:
Expand Down Expand Up @@ -505,7 +521,7 @@ def half_step_chain_fn(
return P, state

def update_chain_mass_fn(
chain_state: NoseHooverChain, kT: torch.Tensor
chain_state: NoseHooverChain, kT: float | torch.Tensor
) -> NoseHooverChain:
"""Update chain masses to maintain target oscillation period.

Expand All @@ -523,10 +539,9 @@ def update_chain_mass_fn(
n_systems = chain_state.kinetic_energy.shape[0]

# Ensure kT has proper batch dimension
if isinstance(kT, torch.Tensor):
kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT
else:
kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device)
kT_batched = torch.as_tensor(kT, device=device, dtype=dtype)
if kT_batched.ndim == 0:
kT_batched = kT_batched.expand(n_systems)

Q = (
kT_batched.unsqueeze(-1)
Expand Down
Loading
Loading