From 55e55c9a3ed557df6a096d68ba9be1fd8faa35f5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 22:29:49 -0500 Subject: [PATCH 1/5] maint: address #325 -- speed not benchmarked. --- torch_sim/integrators/md.py | 44 ++++--- torch_sim/integrators/npt.py | 160 ++++++++++++----------- torch_sim/integrators/nve.py | 5 +- torch_sim/integrators/nvt.py | 40 +++--- torch_sim/models/morse.py | 16 +-- torch_sim/models/soft_sphere.py | 16 +-- torch_sim/optimizers/bfgs.py | 26 ++-- torch_sim/optimizers/cell_filters.py | 23 ++-- torch_sim/optimizers/fire.py | 31 +++-- torch_sim/optimizers/gradient_descent.py | 5 +- torch_sim/optimizers/lbfgs.py | 17 ++- torch_sim/runners.py | 31 +---- 12 files changed, 214 insertions(+), 200 deletions(-) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 3185b6cdf..604316b11 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -120,7 +120,8 @@ def calculate_momenta( if seed is not None: generator.manual_seed(seed) - if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: + kT = torch.as_tensor(kT, device=device, dtype=dtype) + if kT.ndim > 0: # kT is a tensor with shape (n_systems,) kT = kT[system_idx] @@ -166,6 +167,7 @@ def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: MDState: Updated state with new momenta after force application """ + dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype) new_momenta = state.momenta + state.forces * dt state.set_constrained_momenta(new_momenta) return state @@ -186,12 +188,15 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: MDState: Updated state with new positions after propagation """ + dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype) new_positions = state.positions + state.velocities * dt state.set_constrained_positions(new_positions) return state -def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterface) -> T: +def velocity_verlet[T: 
MDState]( + state: T, dt: float | torch.Tensor, model: ModelInterface +) -> T: """Perform one complete velocity Verlet integration step. This function implements the velocity Verlet algorithm, which provides @@ -215,6 +220,7 @@ def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterfac - Conserves energy in the absence of numerical errors - Handles periodic boundary conditions if enabled in state """ + dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype) dt_2 = dt / 2 state = momentum_step(state, dt_2) state = position_step(state, dt) @@ -317,11 +323,11 @@ class NoseHooverChainFns: @dcite("10.2183/pjab.69.161") @dcite("10.1016/0375-9601(90)90092-3") def construct_nose_hoover_chain( # noqa: C901 PLR0915 - dt: torch.Tensor, + dt: float | torch.Tensor, chain_length: int, chain_steps: int, sy_steps: int, - tau: torch.Tensor, + tau: float | torch.Tensor, ) -> NoseHooverChainFns: """Creates functions to simulate a Nose-Hoover Chain thermostat. @@ -378,16 +384,14 @@ def init_fn( p_xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device) # Broadcast tau to match n_systems - if isinstance(tau, torch.Tensor): - tau_batched = tau.expand(n_systems) if tau.dim() == 0 else tau - else: - tau_batched = torch.full((n_systems,), tau, dtype=dtype, device=device) + tau_batched = torch.as_tensor(tau, device=device, dtype=dtype) + if tau_batched.ndim == 0: + tau_batched = tau_batched.expand(n_systems) # Ensure kT has proper batch dimension - if isinstance(kT, torch.Tensor): - kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT - else: - kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + kT_batched = torch.as_tensor(kT, device=device, dtype=dtype) + if kT_batched.ndim == 0: + kT_batched = kT_batched.expand(n_systems) Q = ( kT_batched.unsqueeze(-1) @@ -433,10 +437,9 @@ def substep_fn( M = chain_length - 1 # Ensure kT has proper batch dimension - if isinstance(kT, torch.Tensor): - kT_batched = 
kT.expand(KE.shape[0]) if kT.dim() == 0 else kT - else: - kT_batched = torch.full_like(KE, kT) + kT_batched = torch.as_tensor(kT, device=KE.device, dtype=KE.dtype) + if kT_batched.ndim == 0: + kT_batched = kT_batched.expand(KE.shape[0]) # Update chain momenta backwards if M > 0: @@ -505,7 +508,7 @@ def half_step_chain_fn( return P, state def update_chain_mass_fn( - chain_state: NoseHooverChain, kT: torch.Tensor + chain_state: NoseHooverChain, kT: float | torch.Tensor ) -> NoseHooverChain: """Update chain masses to maintain target oscillation period. @@ -523,10 +526,9 @@ def update_chain_mass_fn( n_systems = chain_state.kinetic_energy.shape[0] # Ensure kT has proper batch dimension - if isinstance(kT, torch.Tensor): - kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT - else: - kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + kT_batched = torch.as_tensor(kT, device=device, dtype=dtype) + if kT_batched.ndim == 0: + kT_batched = kT_batched.expand(n_systems) Q = ( kT_batched.unsqueeze(-1) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index a709c4c2a..a4b859f75 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -437,8 +437,8 @@ def _npt_langevin_velocity_step( def _compute_cell_force( state: NPTLangevinState, - external_pressure: torch.Tensor, - kT: torch.Tensor, + external_pressure: float | torch.Tensor, + kT: float | torch.Tensor, ) -> torch.Tensor: """Compute forces on the cell for NPT dynamics. 
@@ -458,14 +458,12 @@ def _compute_cell_force( torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] """ # Convert external_pressure to tensor if it's not already one - if not isinstance(external_pressure, torch.Tensor): - external_pressure = torch.tensor( - external_pressure, device=state.device, dtype=state.dtype - ) + external_pressure = torch.as_tensor( + external_pressure, device=state.device, dtype=state.dtype + ) # Convert kT to tensor if it's not already one - if not isinstance(kT, torch.Tensor): - kT = torch.tensor(kT, device=state.device, dtype=state.dtype) + kT = torch.as_tensor(kT, device=state.device, dtype=state.dtype) # Get current volumes for each batch volumes = torch.linalg.det(state.cell) # shape: (n_systems,) @@ -634,9 +632,9 @@ def npt_langevin_step( state: NPTLangevinState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, ) -> NPTLangevinState: """Perform one complete NPT Langevin dynamics integration step. 
@@ -667,14 +665,10 @@ def npt_langevin_step( device, dtype = model.device, model.dtype # Convert any scalar parameters to tensors with batch dimension if needed - if isinstance(state.alpha, float): - state.alpha = torch.tensor(state.alpha, device=device, dtype=dtype) - if isinstance(kT, float): - kT = torch.tensor(kT, device=device, dtype=dtype) - if isinstance(state.cell_alpha, float): - state.cell_alpha = torch.tensor(state.cell_alpha, device=device, dtype=dtype) - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) + state.alpha = torch.as_tensor(state.alpha, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + state.cell_alpha = torch.as_tensor(state.cell_alpha, device=device, dtype=dtype) + dt = torch.as_tensor(dt, device=device, dtype=dtype) # Make sure parameters have batch dimension if they're scalars batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT @@ -925,8 +919,7 @@ def _npt_nose_hoover_update_cell_mass( _n_particles, dim = state.positions.shape # Convert kT to tensor if it's not already one - if not isinstance(kT, torch.Tensor): - kT = torch.tensor(kT, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT kT_system = kT.expand(state.n_systems) if kT.ndim == 0 else kT @@ -1274,13 +1267,13 @@ def npt_nose_hoover_init( state: SimState | StateDict, model: ModelInterface, *, - kT: torch.Tensor, - dt: torch.Tensor, + kT: float | torch.Tensor, + dt: float | torch.Tensor, chain_length: int = 3, chain_steps: int = 2, sy_steps: int = 3, - t_tau: torch.Tensor | None = None, - b_tau: torch.Tensor | None = None, + t_tau: float | torch.Tensor | None = None, + b_tau: float | torch.Tensor | None = None, seed: int | None = None, **kwargs: Any, ) -> NPTNoseHooverState: @@ -1327,15 +1320,15 @@ def npt_nose_hoover_init( - All cell properties are properly initialized with batch dimensions """ device, dtype = model.device, model.dtype + dt = 
torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) # Initialize the NPT Nose-Hoover state # Thermostat relaxation time - if t_tau is None: - t_tau = 100 * dt + t_tau = torch.as_tensor(t_tau or 100 * dt, device=device, dtype=dtype) # Barostat relaxation time - if b_tau is None: - b_tau = 1000 * dt + b_tau = torch.as_tensor(b_tau or 1000 * dt, device=device, dtype=dtype) # Setup thermostats with appropriate timescales barostat_fns = construct_nose_hoover_chain( @@ -1358,8 +1351,7 @@ def npt_nose_hoover_init( cell_momentum = torch.zeros(n_systems, 1, device=device, dtype=dtype) # Convert kT to tensor if it's not already one - if not isinstance(kT, torch.Tensor): - kT = torch.tensor(kT, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT kT_system = kT.expand(n_systems) if kT.ndim == 0 else kT @@ -1439,9 +1431,9 @@ def npt_nose_hoover_step( state: NPTNoseHooverState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, ) -> NPTNoseHooverState: """Perform a complete NPT integration step with Nose-Hoover chain thermostats. 
If the center of mass motion is removed initially, it remains removed throughout @@ -1465,6 +1457,9 @@ def npt_nose_hoover_step( NPTNoseHooverState: Updated state after complete integration step """ device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) # Unpack state variables for clarity barostat = state.barostat @@ -1972,10 +1967,10 @@ def npt_crescale_anisotropic_step( state: NPTCRescaleState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, - tau: torch.Tensor | None = None, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, + tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. @@ -2013,9 +2008,14 @@ def npt_crescale_anisotropic_step( Returns: NPTCRescaleState: Updated state after one integration step """ + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) + # Note: would probably be better to have tau in NVTCRescaleState - if tau is None: - tau = 100 * dt + tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + state = _vrescale_update(state, tau, kT, dt / 2) state = momentum_step(state, dt / 2) @@ -2042,10 +2042,10 @@ def npt_crescale_independent_lengths_step( state: NPTCRescaleState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, - tau: torch.Tensor | None = None, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, + tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: """Perform one NPT 
integration step with cell rescaling barostat. @@ -2083,9 +2083,14 @@ def npt_crescale_independent_lengths_step( Returns: NPTCRescaleState: Updated state after one integration step """ + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) + # Note: would probably be better to have tau in NVTCRescaleState - if tau is None: - tau = 100 * dt + tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + state = _vrescale_update(state, tau, kT, dt / 2) state = momentum_step(state, dt / 2) @@ -2112,10 +2117,10 @@ def npt_crescale_average_anisotropic_step( state: NPTCRescaleState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, - tau: torch.Tensor | None = None, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, + tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. 
@@ -2154,9 +2159,14 @@ def npt_crescale_average_anisotropic_step( Returns: NPTCRescaleState: Updated state after one integration step """ + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) + # Note: would probably be better to have tau in NVTCRescaleState - if tau is None: - tau = 100 * dt + tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + state = _vrescale_update(state, tau, kT, dt / 2) state = momentum_step(state, dt / 2) @@ -2182,10 +2192,10 @@ def npt_crescale_isotropic_step( state: NPTCRescaleState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, - external_pressure: torch.Tensor, - tau: torch.Tensor | None = None, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, + tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. 
@@ -2226,9 +2236,14 @@ def npt_crescale_isotropic_step( Returns: NPTCRescaleState: Updated state after one integration step """ + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) + # Note: would probably be better to have tau in NVTCRescaleState - if tau is None: - tau = 100 * dt + tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + state = _vrescale_update(state, tau, kT, dt / 2) state = momentum_step(state, dt / 2) @@ -2253,10 +2268,10 @@ def npt_crescale_init( state: SimState | StateDict, model: ModelInterface, *, - kT: torch.Tensor, - dt: torch.Tensor, - tau_p: torch.Tensor | None = None, - isothermal_compressibility: torch.Tensor | None = None, + kT: float | torch.Tensor, + dt: float | torch.Tensor, + tau_p: float | torch.Tensor | None = None, + isothermal_compressibility: float | torch.Tensor | None = None, seed: int | None = None, ) -> NPTCRescaleState: """Initialize the NPT cell rescaling state. 
@@ -2280,25 +2295,22 @@ def npt_crescale_init( """ device, dtype = model.device, model.dtype - # Set default values if not provided - if tau_p is None: - tau_p = 5000 * dt # 5ps for dt=1fs - if isothermal_compressibility is None: - isothermal_compressibility = 1e-1 # (eV/A^3)^-1 - # Convert all parameters to tensors with correct device and dtype - tau_p = torch.as_tensor(tau_p, device=device, dtype=dtype) + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + + # Set default values if not provided + tau_p = torch.as_tensor(tau_p or 5000 * dt, device=device, dtype=dtype) # (eV/A^3)^-1 isothermal_compressibility = torch.as_tensor( - isothermal_compressibility, device=device, dtype=dtype + isothermal_compressibility or 1e-1, + device=device, + dtype=dtype, # (eV/A^3)^-1 ) + if tau_p.ndim == 0: tau_p = tau_p.expand(state.n_systems) if isothermal_compressibility.ndim == 0: isothermal_compressibility = isothermal_compressibility.expand(state.n_systems) - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) - if isinstance(kT, float): - kT = torch.tensor(kT, device=device, dtype=dtype) if not isinstance(state, SimState): state = SimState(**state) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 4880cfac6..f49de51d4 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -19,7 +19,7 @@ def nve_init( state: SimState | StateDict, model: ModelInterface, *, - kT: torch.Tensor, + kT: float | torch.Tensor, seed: int | None = None, **_kwargs: Any, ) -> MDState: @@ -66,7 +66,7 @@ def nve_init( def nve_step( - state: MDState, model: ModelInterface, *, dt: torch.Tensor, **_kwargs: Any + state: MDState, model: ModelInterface, *, dt: float | torch.Tensor, **_kwargs: Any ) -> MDState: """Perform one complete NVE (microcanonical) integration step. 
@@ -93,6 +93,7 @@ def nve_step( - Handles periodic boundary conditions if enabled in state - Symplectic integrator preserving phase space volume """ + dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype) state = momentum_step(state, dt / 2) state = position_step(state, dt) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 957ad89f9..76264a44b 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -171,14 +171,9 @@ def nvt_langevin_step( """ device, dtype = model.device, model.dtype - if gamma is None: - gamma = 1 / (100 * dt) - - if isinstance(gamma, float): - gamma = torch.tensor(gamma, device=device, dtype=dtype) - - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + gamma = torch.as_tensor(gamma or 1 / (100 * dt), device=device, dtype=dtype) state = momentum_step(state, dt / 2) state = position_step(state, dt / 2) @@ -246,9 +241,9 @@ def nvt_nose_hoover_init( state: SimState | StateDict, model: ModelInterface, *, - kT: torch.Tensor, - dt: torch.Tensor, - tau: torch.Tensor | None = None, + kT: float | torch.Tensor, + dt: float | torch.Tensor, + tau: float | torch.Tensor | None = None, chain_length: int = 3, chain_steps: int = 3, sy_steps: int = 3, @@ -284,8 +279,10 @@ def nvt_nose_hoover_init( - Chain variables evolve to maintain target temperature - Time-reversible when integrated with appropriate algorithms """ - if tau is None: # Set default tau if not provided - tau = dt * 100.0 + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + tau = torch.as_tensor(tau or dt * 100.0, device=device, dtype=dtype) # Create thermostat functions chain_fns = construct_nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau) @@ -328,8 +325,8 @@ def nvt_nose_hoover_step( 
state: NVTNoseHooverState, model: ModelInterface, *, - dt: torch.Tensor, - kT: torch.Tensor, + dt: float | torch.Tensor, + kT: float | torch.Tensor, ) -> NVTNoseHooverState: """Perform one complete Nose-Hoover chain integration step. @@ -356,6 +353,10 @@ def nvt_nose_hoover_step( 4. Update chain kinetic energy 5. Second half-step of chain evolution """ + device, dtype = model.device, model.dtype + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + # Get chain functions from state chain_fns = state._chain_fns # noqa: SLF001 chain = state.chain @@ -649,12 +650,9 @@ def nvt_vrescale_step( if tau is None: tau = 100 * dt - if isinstance(tau, float): - tau = torch.tensor(tau, device=device, dtype=dtype) - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) - if isinstance(kT, float): - kT = torch.tensor(kT, device=device, dtype=dtype) + tau = torch.as_tensor(tau, device=device, dtype=dtype) + dt = torch.as_tensor(dt, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) # Apply V-Rescale rescaling state = _vrescale_update(state, tau, kT, dt) diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 78c62f3f3..feb9534f2 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -152,9 +152,9 @@ class MorseModel(ModelInterface): def __init__( self, - sigma: float = 1.0, - epsilon: float = 5.0, - alpha: float = 5.0, + sigma: float | torch.Tensor = 1.0, + epsilon: float | torch.Tensor = 5.0, + alpha: float | torch.Tensor = 5.0, device: torch.device | None = None, dtype: torch.dtype = torch.float32, *, # Force keyword-only arguments @@ -163,7 +163,7 @@ def __init__( per_atom_energies: bool = False, per_atom_stresses: bool = False, use_neighbor_list: bool = True, - cutoff: float | None = None, + cutoff: float | torch.Tensor | None = None, ) -> None: """Initialize the Morse potential calculator. 
@@ -219,12 +219,12 @@ def __init__( self._per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device) - self.cutoff = torch.tensor( + self.sigma = torch.as_tensor(sigma, dtype=self.dtype, device=self.device) + self.cutoff = torch.as_tensor( cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device ) - self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) - self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device) + self.epsilon = torch.as_tensor(epsilon, dtype=self.dtype, device=self.device) + self.alpha = torch.as_tensor(alpha, dtype=self.dtype, device=self.device) def unbatched_forward( self, state: ts.SimState | StateDict diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 60d647829..db1f8ec42 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -181,9 +181,9 @@ class SoftSphereModel(ModelInterface): def __init__( self, - sigma: float = 1.0, - epsilon: float = 1.0, - alpha: float = 2.0, + sigma: float | torch.Tensor = 1.0, + epsilon: float | torch.Tensor = 1.0, + alpha: float | torch.Tensor = 2.0, device: torch.device | None = None, dtype: torch.dtype = torch.float32, *, # Force keyword-only arguments @@ -192,7 +192,7 @@ def __init__( per_atom_energies: bool = False, per_atom_stresses: bool = False, use_neighbor_list: bool = True, - cutoff: float | None = None, + cutoff: float | torch.Tensor | None = None, ) -> None: """Initialize the soft sphere model. 
@@ -244,10 +244,10 @@ def __init__( self.use_neighbor_list = use_neighbor_list # Convert interaction parameters to tensors with proper dtype/device - self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device) - self.cutoff = torch.tensor(cutoff or sigma, dtype=dtype, device=self.device) - self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device) - self.alpha = torch.tensor(alpha, dtype=dtype, device=self.device) + self.sigma = torch.as_tensor(sigma, dtype=dtype, device=self.device) + self.cutoff = torch.as_tensor(cutoff or sigma, dtype=dtype, device=self.device) + self.epsilon = torch.as_tensor(epsilon, dtype=dtype, device=self.device) + self.alpha = torch.as_tensor(alpha, dtype=dtype, device=self.device) def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: """Compute energies and forces for a single unbatched system. diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index 458f90979..33453daf5 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -83,8 +83,8 @@ def bfgs_init( state: SimState | StateDict, model: "ModelInterface", *, - max_step: float = 0.2, - alpha: float = 70.0, + max_step: float | torch.Tensor = 0.2, + alpha: float | torch.Tensor = 70.0, cell_filter: "CellFilter | CellFilterFuncs | None" = None, **filter_kwargs: Any, ) -> "BFGSState | CellBFGSState": @@ -130,17 +130,23 @@ def bfgs_init( forces = model_output["forces"] # [N, 3] stress = model_output.get("stress") # [S, 3, 3] or None - alpha_t = torch.full((n_systems,), alpha, **tensor_args) # [S] - max_step_t = torch.full((n_systems,), max_step, **tensor_args) # [S] + alpha_t = torch.as_tensor(alpha, **tensor_args) + if alpha_t.ndim == 0: + alpha_t = alpha_t.expand(n_systems) + + max_step_t = torch.as_tensor(max_step, **tensor_args) + if max_step_t.ndim == 0: + max_step_t = max_step_t.expand(n_systems) + n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] if cell_filter is not 
None: # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) dim = 3 * global_max_atoms + (3 * 3) # D_ext - hessian = ( - torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha - ) # [S, D_ext, D_ext] + hessian = torch.eye(dim, **tensor_args).unsqueeze(0).repeat( + n_systems, 1, 1 + ) * alpha_t.view(n_systems, 1, 1) # [S, D_ext, D_ext] cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) @@ -210,9 +216,9 @@ def bfgs_init( # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms dim = 3 * global_max_atoms # D - hessian = ( - torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha - ) # [S, D, D] + hessian = torch.eye(dim, **tensor_args).unsqueeze(0).repeat( + n_systems, 1, 1 + ) * alpha_t.view(n_systems, 1, 1) # [S, D, D] common_args = { "positions": state.positions.clone(), # [N, 3] diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 18343370a..e59386efa 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -32,12 +32,11 @@ def _setup_cell_factor( # Count atoms per system _, counts = torch.unique(state.system_idx, return_counts=True) cell_factor_tensor = counts.to(dtype=dtype) - elif isinstance(cell_factor, (int, float)): - cell_factor_tensor = torch.full( - (n_systems,), cell_factor, device=device, dtype=dtype - ) else: - cell_factor_tensor = torch.tensor(cell_factor, device=device, dtype=dtype) + cell_factor_tensor = torch.as_tensor(cell_factor, device=device, dtype=dtype) + if cell_factor_tensor.ndim == 0: + cell_factor_tensor = cell_factor_tensor.expand(n_systems) + if (n_cft := cell_factor_tensor.numel()) != n_systems: raise ValueError( f"cell_factor tensor must have {n_systems} elements, got {n_cft}" @@ -254,10 +253,9 @@ class CellFilter(StrEnum): # Filter type definitions for convenience def unit_cell_step[T: 
AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: """Update cell using unit cell approach.""" - if isinstance(cell_lr, (int, float)): - cell_lr = torch.full( - (state.n_systems,), cell_lr, device=state.device, dtype=state.dtype - ) + cell_lr = torch.as_tensor(cell_lr, device=state.device, dtype=state.dtype) + if cell_lr.ndim == 0: + cell_lr = cell_lr.expand(state.n_systems) # Get current deformation gradient cur_deform_grad = deform_grad(state.reference_cell.mT, state.row_vector_cell) @@ -284,10 +282,9 @@ def unit_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: """Update cell using frechet approach.""" - if isinstance(cell_lr, (int, float)): - cell_lr = torch.full( - (state.n_systems,), cell_lr, device=state.device, dtype=state.dtype - ) + cell_lr = torch.as_tensor(cell_lr, device=state.device, dtype=state.dtype) + if cell_lr.ndim == 0: + cell_lr = cell_lr.expand(state.n_systems) cell_wise_lr = cell_lr.view(state.n_systems, 1, 1) # Compute cell step and update cell positions in log space diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 49ab9f2cf..58134e1be 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -27,8 +27,8 @@ def fire_init( state: SimState | StateDict, model: "ModelInterface", *, - dt_start: float = 0.1, - alpha_start: float = 0.1, + dt_start: float | torch.Tensor = 0.1, + alpha_start: float | torch.Tensor = 0.1, fire_flavor: "FireFlavor" = "ase_fire", cell_filter: "CellFilter | CellFilterFuncs | None" = None, **filter_kwargs: Any, @@ -74,14 +74,23 @@ def fire_init( forces = model_output["forces"] stress = model_output.get("stress") + # Setup initial parameters + dt_start_t = torch.as_tensor(dt_start, **tensor_args) + if dt_start_t.ndim == 0: + dt_start_t = dt_start_t.expand(n_systems) + + alpha_start_t = torch.as_tensor(alpha_start, **tensor_args) + if alpha_start_t.ndim 
== 0: + alpha_start_t = alpha_start_t.expand(n_systems) + # FIRE-specific additional attributes fire_attrs = { "forces": forces, "energy": energy, "stress": stress, "velocities": torch.full(state.positions.shape, torch.nan, **tensor_args), - "dt": torch.full((n_systems,), dt_start, **tensor_args), - "alpha": torch.full((n_systems,), alpha_start, **tensor_args), + "dt": dt_start_t, + "alpha": alpha_start_t, "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), } @@ -108,13 +117,13 @@ def fire_step( state: "FireState | CellFireState", model: "ModelInterface", *, - dt_max: float = 1.0, - n_min: int = 5, - f_inc: float = 1.1, - f_dec: float = 0.5, - alpha_start: float = 0.1, - f_alpha: float = 0.99, - max_step: float = 0.2, + dt_max: float | torch.Tensor = 1.0, + n_min: int | torch.Tensor = 5, + f_inc: float | torch.Tensor = 1.1, + f_dec: float | torch.Tensor = 0.5, + alpha_start: float | torch.Tensor = 0.1, + f_alpha: float | torch.Tensor = 0.99, + max_step: float | torch.Tensor = 0.2, fire_flavor: "FireFlavor" = "ase_fire", ) -> "FireState | CellFireState": """Perform one FIRE optimization step. 
diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 23a51a0ed..abe1ba66b 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -96,8 +96,9 @@ def gradient_descent_step( device, dtype = model.device, model.dtype # Get per-atom learning rates - if isinstance(pos_lr, (int, float)): - pos_lr = torch.full((state.n_systems,), pos_lr, device=device, dtype=dtype) + pos_lr = torch.as_tensor(pos_lr, device=device, dtype=dtype) + if pos_lr.ndim == 0: + pos_lr = pos_lr.expand(state.n_systems) atom_lr = pos_lr[state.system_idx].unsqueeze(-1) # Update atomic positions diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 53a7e0bdc..d2dc2fa31 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -111,8 +111,8 @@ def lbfgs_init( state: SimState | StateDict, model: "ModelInterface", *, - step_size: float = 0.1, - alpha: float | None = None, + step_size: float | torch.Tensor = 0.1, + alpha: float | torch.Tensor | None = None, cell_filter: "CellFilter | CellFilterFuncs | None" = None, **filter_kwargs: Any, ) -> "LBFGSState | CellLBFGSState": @@ -185,8 +185,13 @@ def lbfgs_init( ) # [S, 0, M, 3] # Alpha tensor: 0.0 means dynamic, >0 means fixed - alpha_val = 0.0 if alpha is None else alpha - alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) # [S] + alpha_tensor = torch.as_tensor(alpha or 0.0, **tensor_args) + if alpha_tensor.ndim == 0: + alpha_tensor = alpha_tensor.expand(n_systems) + + step_size_tensor = torch.as_tensor(step_size, **tensor_args) + if step_size_tensor.ndim == 0: + step_size_tensor = step_size_tensor.expand(n_systems) common_args = { # Copy SimState attributes @@ -208,7 +213,7 @@ def lbfgs_init( "prev_positions": state.positions.clone(), # [N, 3] "s_history": s_history, # [S, 0, M, 3] "y_history": y_history, # [S, 0, M, 3] - "step_size": torch.full((n_systems,), step_size, **tensor_args), # [S] + 
"step_size": step_size_tensor, # [S] "alpha": alpha_tensor, # [S] "n_iter": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), "max_atoms": max_atoms, # [S] atoms per system for padding @@ -275,7 +280,7 @@ def lbfgs_step( # noqa: PLR0915, C901 model: "ModelInterface", *, max_history: int = 20, - max_step: float = 0.2, + max_step: float | torch.Tensor = 0.2, curvature_eps: float = 1e-12, ) -> "LBFGSState | CellLBFGSState": r"""Advance one L-BFGS iteration using the two-loop recursion. diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8f2772894..1355506ff 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -183,35 +183,18 @@ def _normalize_temperature_tensor( torch.Tensor: Normalized temperature tensor """ # ---- Step 1: Convert to tensor ---- - if isinstance(temperature, (float, int)): - return torch.full( - (n_steps,), - float(temperature), - dtype=initial_state.dtype, - device=initial_state.device, - ) - - # Convert list or tensor input to tensor - if isinstance(temperature, list): - temps = torch.tensor( - temperature, dtype=initial_state.dtype, device=initial_state.device - ) - elif isinstance(temperature, torch.Tensor): - temps = temperature.to(dtype=initial_state.dtype, device=initial_state.device) - else: - raise TypeError( - f"Invalid temperature type: {type(temperature).__name__}. " - "Must be float, int, list, or torch.Tensor." 
- ) + temps = torch.as_tensor( + temperature, dtype=initial_state.dtype, device=initial_state.device + ) # ---- Step 2: Determine how to broadcast ---- temps = torch.atleast_1d(temps) if temps.ndim > 2: raise ValueError(f"Temperature tensor must be 1D or 2D, got shape {temps.shape}.") - if temps.shape[0] == 1: - # A single value in a 1-element list/tensor - return temps.repeat(n_steps) + if temps.numel() == 1: + # A single value + return temps.expand(n_steps) if initial_state.n_systems == n_steps: warnings.warn( @@ -313,7 +296,7 @@ def integrate[T: SimState]( # noqa: C901 dtype, device = initial_state.dtype, initial_state.device kTs = _normalize_temperature_tensor(temperature, n_steps, initial_state) kTs = kTs * unit_system.temperature - dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device) + dt = torch.as_tensor(timestep * unit_system.time, dtype=dtype, device=device) # Handle both string names and direct function tuples if isinstance(integrator, Integrator): From a164f42283f9bddcdfee1253f78b1c2df9eabcac Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 23 Feb 2026 22:35:28 -0500 Subject: [PATCH 2/5] driveby: rename vv to vv_step --- torch_sim/integrators/__init__.py | 8 +++++++- torch_sim/integrators/md.py | 15 ++++++++++++++- torch_sim/integrators/nvt.py | 6 +++--- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index bae4b5cd3..3405b732e 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -68,7 +68,13 @@ import torch_sim as ts -from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet +from .md import ( + MDState, + calculate_momenta, + momentum_step, + position_step, + velocity_verlet_step, +) from .npt import ( NPTLangevinState, NPTNoseHooverState, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 604316b11..7078e2799 100644 --- 
a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -1,5 +1,6 @@ """Core molecular dynamics state and operations.""" +import warnings from collections.abc import Callable from dataclasses import dataclass @@ -194,7 +195,7 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: return state -def velocity_verlet[T: MDState]( +def velocity_verlet_step[T: MDState]( state: T, dt: float | torch.Tensor, model: ModelInterface ) -> T: """Perform one complete velocity Verlet integration step. @@ -232,6 +233,18 @@ def velocity_verlet[T: MDState]( return momentum_step(state, dt_2) +def velocity_verlet[T: MDState]( + state: T, dt: float | torch.Tensor, model: ModelInterface +) -> T: + """Deprecated alias for velocity_verlet_step.""" + warnings.warn( + "velocity_verlet is deprecated. Use velocity_verlet_step instead.", + DeprecationWarning, + stacklevel=2, + ) + return velocity_verlet_step(state=state, dt=dt, model=model) + + @dataclass class NoseHooverChain: """State information for a Nose-Hoover chain thermostat. 
diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 76264a44b..c35ea5a2b 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -15,7 +15,7 @@ construct_nose_hoover_chain, momentum_step, position_step, - velocity_verlet, + velocity_verlet_step, ) from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -369,7 +369,7 @@ def nvt_nose_hoover_step( state.set_constrained_momenta(momenta) # Full velocity Verlet step - state = velocity_verlet(state=state, dt=dt, model=model) + state = velocity_verlet_step(state=state, dt=dt, model=model) # Update chain kinetic energy per system KE = ts.calc_kinetic_energy( @@ -658,4 +658,4 @@ def nvt_vrescale_step( state = _vrescale_update(state, tau, kT, dt) # Perform velocity Verlet step - return velocity_verlet(state=state, dt=dt, model=model) + return velocity_verlet_step(state=state, dt=dt, model=model) From 03e8a5b4a8be2afc4807e4198f0948e2e547bed7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 24 Feb 2026 13:57:58 -0500 Subject: [PATCH 3/5] Update torch_sim/integrators/npt.py Signed-off-by: Rhys Goodall --- torch_sim/integrators/npt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index a4b859f75..07a06e290 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -2300,7 +2300,7 @@ def npt_crescale_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) # Set default values if not provided - tau_p = torch.as_tensor(tau_p or 5000 * dt, device=device, dtype=dtype) # (eV/A^3)^-1 + tau_p = torch.as_tensor(tau_p or 5000 * dt, device=device, dtype=dtype) # 5ps for dt=1fs isothermal_compressibility = torch.as_tensor( isothermal_compressibility or 1e-1, device=device, From 697da75c043e5fdede6501af2696814d4dc190b3 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Feb 2026 13:51:50 -0500 Subject: [PATCH 4/5] detach tensors before 
moving to cpu and numpy. --- examples/scripts/5_workflow.py | 2 +- examples/scripts/7_others.py | 2 +- tests/models/test_fairchem.py | 30 ++++++++++++++++++++++++++++++ tests/models/test_orb.py | 2 +- tests/test_math.py | 32 ++++++++++++++++++++------------ tests/test_quantities.py | 10 +++++----- tests/test_state.py | 4 ++-- tests/workflows/test_a2c.py | 2 +- torch_sim/models/fairchem.py | 8 ++++---- torch_sim/models/mace.py | 2 +- torch_sim/workflows/a2c.py | 11 +++++++---- 11 files changed, 73 insertions(+), 32 deletions(-) diff --git a/examples/scripts/5_workflow.py b/examples/scripts/5_workflow.py index e92c4b424..604f1bf8a 100644 --- a/examples/scripts/5_workflow.py +++ b/examples/scripts/5_workflow.py @@ -213,7 +213,7 @@ # Print results print("\nElastic tensor (GPa):") -elastic_tensor_np = elastic_tensor.cpu().numpy() +elastic_tensor_np = elastic_tensor.detach().cpu().numpy() for row in elastic_tensor_np: print(" " + " ".join(f"{val:10.4f}" for val in row)) diff --git a/examples/scripts/7_others.py b/examples/scripts/7_others.py index ba88ebe0e..c84ff2dda 100644 --- a/examples/scripts/7_others.py +++ b/examples/scripts/7_others.py @@ -187,7 +187,7 @@ time = time_steps * correlation_dt * timestep * 1000 # Convert to fs if vacf_calc.vacf is not None: - vacf_data = vacf_calc.vacf.cpu().numpy() + vacf_data = vacf_calc.vacf.detach().cpu().numpy() print("\nVACF calculation complete:") print(f" Number of windows averaged: {vacf_calc._window_count}") # noqa: SLF001 print(f" VACF at t=0: {vacf_data[0]:.4f}") diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index b992f8172..58e87c7b3 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -288,3 +288,33 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None: # Verify outputs are finite assert torch.isfinite(result["energy"]).all() assert torch.isfinite(result["forces"]).all() + + +# TODO: we should perhaps put something like this inside 
`validate_model_outputs` +# the question is how we can do this without creating a circular dependency +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None: + """Test a single optimization step with FairChemModel. + + This verifies that the model works correctly with optimizers, particularly + that it doesn't have issues with the computational graph (e.g., missing + .detach() calls). + """ + model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE) + state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE) + + # Initialize FIRE optimizer + opt_state = ts.fire_init(state, model) + initial_positions = opt_state.positions.clone() + _initial_energy = opt_state.energy.item() + + # Run exactly one step + opt_state = ts.fire_step(opt_state, model) + + # Verify positions changed + assert not torch.allclose(opt_state.positions, initial_positions) + # Verify energy is still available and finite + assert torch.isfinite(opt_state.energy).all() + assert isinstance(opt_state.energy.item(), float) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 33415da7e..6f3de1a6f 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -83,7 +83,7 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: def test_cell_to_cellpar(ti_sim_state: SimState) -> None: assert np.allclose( ase_c2p(ti_sim_state.row_vector_cell.squeeze()), - cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).cpu().numpy(), + cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).detach().cpu().numpy(), ) assert np.allclose( ase_c2p(ti_sim_state.row_vector_cell.squeeze(), radians=True), diff --git a/tests/test_math.py b/tests/test_math.py index 04e16e992..a6dfda571 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -28,8 +28,10 @@ def test_expm_frechet(self): A = torch.from_numpy(A).to(device=device) E =
torch.from_numpy(E).to(device=device) observed_expm, observed_frechet = fm.expm_frechet(A, E) - assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14) - assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14) + assert_allclose(expected_expm, observed_expm.detach().cpu().numpy(), atol=1e-14) + assert_allclose( + expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-14 + ) def test_small_norm_expm_frechet(self): """Test matrices with small norms.""" @@ -41,8 +43,10 @@ def test_small_norm_expm_frechet(self): A = torch.from_numpy(A).to(device=device, dtype=DTYPE) E = torch.from_numpy(E).to(device=device, dtype=DTYPE) observed_expm, observed_frechet = fm.expm_frechet(A, E) - assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14) - assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14) + assert_allclose(expected_expm, observed_expm.detach().cpu().numpy(), atol=1e-14) + assert_allclose( + expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-14 + ) def test_fuzz(self): """Test with a variety of random 3x3 inputs to ensure robustness.""" @@ -61,8 +65,12 @@ def test_fuzz(self): A = torch.from_numpy(A).to(device=device, dtype=DTYPE) E = torch.from_numpy(E).to(device=device, dtype=DTYPE) observed_expm, observed_frechet = fm.expm_frechet(A, E) - assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=5e-8) - assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-7) + assert_allclose( + expected_expm, observed_expm.detach().cpu().numpy(), atol=5e-8 + ) + assert_allclose( + expected_frechet, observed_frechet.detach().cpu().numpy(), atol=1e-7 + ) def test_problematic_matrix(self): """Test a specific matrix that previously uncovered a bug.""" @@ -79,8 +87,8 @@ def test_problematic_matrix(self): blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A.unsqueeze(0), E.unsqueeze(0), method="blockEnlarge" ) - assert_allclose(sps_expm, 
blockEnlarge_expm[0].cpu().numpy()) - assert_allclose(sps_frechet, blockEnlarge_frechet[0].cpu().numpy()) + assert_allclose(sps_expm, blockEnlarge_expm[0].detach().cpu().numpy()) + assert_allclose(sps_frechet, blockEnlarge_frechet[0].detach().cpu().numpy()) def test_medium_matrix(self): """Test with a medium-sized matrix to compare performance between methods.""" @@ -96,8 +104,8 @@ def test_medium_matrix(self): blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A.unsqueeze(0), E.unsqueeze(0), method="blockEnlarge" ) - assert_allclose(sps_expm, blockEnlarge_expm[0].cpu().numpy()) - assert_allclose(sps_frechet, blockEnlarge_frechet[0].cpu().numpy()) + assert_allclose(sps_expm, blockEnlarge_expm[0].detach().cpu().numpy()) + assert_allclose(sps_frechet, blockEnlarge_frechet[0].detach().cpu().numpy()) class TestExpmFrechetTorch: @@ -341,7 +349,7 @@ def test_random_float(self): n = 3 M = torch.randn(n, n, dtype=DTYPE, device=device) M_logm = fm.matrix_log_33(M) - scipy_logm = scipy.linalg.logm(M.cpu().numpy()) + scipy_logm = scipy.linalg.logm(M.detach().cpu().numpy()) torch.testing.assert_close( M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device) ) @@ -360,7 +368,7 @@ def test_nearly_degenerate(self): device=device, ) M_logm = fm._matrix_log_33(M) - scipy_logm = scipy.linalg.logm(M.cpu().numpy()) + scipy_logm = scipy.linalg.logm(M.detach().cpu().numpy()) torch.testing.assert_close( M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device) ) diff --git a/tests/test_quantities.py b/tests/test_quantities.py index bb8f6fd69..3eec4ad92 100644 --- a/tests/test_quantities.py +++ b/tests/test_quantities.py @@ -57,7 +57,7 @@ def test_unbatched_total_flux( # Heat flux parts should cancel out expected = torch.zeros(3, device=flux.device) - assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy()) def test_unbatched_virial_only( self, mock_simple_system: dict[str, torch.Tensor] @@ 
-73,7 +73,7 @@ def test_unbatched_virial_only( ) expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device) - assert_allclose(virial.cpu().numpy(), expected.cpu().numpy()) + assert_allclose(virial.detach().cpu().numpy(), expected.detach().cpu().numpy()) def test_batched_calculation(self) -> None: """Test heat flux calculation with batched data.""" @@ -107,7 +107,7 @@ def test_batched_calculation(self) -> None: # Each batch should cancel heat flux parts expected = torch.zeros((2, 3), device=DEVICE) - assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy()) def test_centroid_stress(self) -> None: """Test heat flux with centroid stress formulation.""" @@ -130,7 +130,7 @@ def test_centroid_stress(self) -> None: # Heatflux should be [-1,-1,-1] expected = torch.full((3,), -1.0, device=DEVICE) - assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy()) def test_momenta_input(self) -> None: """Test heat flux calculation using momenta instead.""" @@ -149,7 +149,7 @@ def test_momenta_input(self) -> None: # Heat flux terms should cancel out expected = torch.zeros(3, device=DEVICE) - assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + assert_allclose(flux.detach().cpu().numpy(), expected.detach().cpu().numpy()) @pytest.fixture diff --git a/tests/test_state.py b/tests/test_state.py index 3b42b0227..1ed638805 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -702,7 +702,7 @@ def test_state_set_cell(ti_sim_state: SimState) -> None: ) ase_atoms = ti_sim_state.to_atoms()[0] ti_sim_state.set_cell(new_cell, scale_atoms=True) - ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True) + ase_atoms.set_cell(new_cell[0].T.detach().cpu().numpy(), scale_atoms=True) assert torch.allclose( ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) ) @@ -715,7 +715,7 @@ def 
test_state_set_cell(ti_sim_state: SimState) -> None: new_cell = M @ ti_sim_state.cell ase_atoms = ti_sim_state.to_atoms()[0] ti_sim_state.set_cell(new_cell, scale_atoms=True) - ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True) + ase_atoms.set_cell(new_cell[0].T.detach().cpu().numpy(), scale_atoms=True) assert torch.allclose( ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) ) diff --git a/tests/workflows/test_a2c.py b/tests/workflows/test_a2c.py index 27dd26d7e..68a4cc840 100644 --- a/tests/workflows/test_a2c.py +++ b/tests/workflows/test_a2c.py @@ -514,7 +514,7 @@ def test_get_subcells_with_composition_restrictions( # Check that all candidates match the requested compositions for ids, _, _ in candidates: - subcell_species = [species[int(i)] for i in ids.cpu().numpy()] + subcell_species = [species[int(i)] for i in ids.detach().cpu().numpy()] comp = Composition("".join(subcell_species)).reduced_formula assert comp in restrict_to_compositions, ( f"Found composition {comp} not in {restrict_to_compositions}" diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 602cf7529..a692216d8 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -201,11 +201,11 @@ def forward(self, state: ts.SimState | StateDict) -> dict: zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False) ): # Extract system data - positions = sim_state.positions[c - n : c].cpu().numpy() - atomic_nums = sim_state.atomic_numbers[c - n : c].cpu().numpy() - pbc = sim_state.pbc.cpu().numpy() + positions = sim_state.positions[c - n : c].detach().cpu().numpy() + atomic_nums = sim_state.atomic_numbers[c - n : c].detach().cpu().numpy() + pbc = sim_state.pbc.detach().cpu().numpy() cell = ( - sim_state.row_vector_cell[idx].cpu().numpy() + sim_state.row_vector_cell[idx].detach().cpu().numpy() if sim_state.row_vector_cell is not None else None ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 678405ef8..56a4467fd 
100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -230,7 +230,7 @@ def setup_from_system_idx( self.node_attrs = to_one_hot( torch.tensor( atomic_numbers_to_indices( - atomic_numbers.cpu().numpy(), z_table=self.z_table + atomic_numbers.detach().cpu().numpy(), z_table=self.z_table ), dtype=torch.long, device=self.device, diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index f7bac0e73..634ee93ab 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -307,7 +307,7 @@ def random_packed_structure( if min_distance(state.positions, cell, distance_tolerance) > diameter * 0.95: break - log.append(state.positions.cpu().numpy()) + log.append(state.positions.detach().cpu().numpy()) state = ts.fire_step(state, model) @@ -395,7 +395,10 @@ def random_packed_structure_multi( # If auto_diameter enabled, calculate species-specific diameter matrix if auto_diameter: diameter_matrix = get_diameter_matrix(composition, device=device, dtype=dtype) - print(f"Using random pack diameter matrix:\n{diameter_matrix.cpu().numpy()}") + print( + f"Using random pack diameter matrix:\n" + f"{diameter_matrix.detach().cpu().numpy()}" + ) # Perform overlap minimization if diameter matrix is specified if diameter_matrix is not None: @@ -611,7 +614,7 @@ def get_subcells_to_crystallize( # Apply composition restrictions if specified if restrict_to_compositions: subcell_comp = Composition( - "".join(species_array[ids.cpu().numpy()]) + "".join(species_array[ids.detach().cpu().numpy()]) ).reduced_formula if subcell_comp not in restrict_to_compositions: continue @@ -657,7 +660,7 @@ def subcells_to_structures( # Get species for these atoms and convert tensor indices to list/numpy array # before indexing species list - subcell_species = [species[int(i)] for i in ids.cpu().numpy()] + subcell_species = [species[int(i)] for i in ids.detach().cpu().numpy()] list_subcells.append((new_frac_pos, new_cell, subcell_species)) From 
3cdfc24afee06abd4c4d11881afbeb052edbe19d Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Feb 2026 13:52:21 -0500 Subject: [PATCH 5/5] lint --- torch_sim/integrators/npt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 07a06e290..a7d57313d 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -2300,7 +2300,9 @@ def npt_crescale_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) # Set default values if not provided - tau_p = torch.as_tensor(tau_p or 5000 * dt, device=device, dtype=dtype) # 5ps for dt=1fs + tau_p = torch.as_tensor( + tau_p or 5000 * dt, device=device, dtype=dtype + ) # 5ps for dt=1fs isothermal_compressibility = torch.as_tensor( isothermal_compressibility or 1e-1, device=device,