diff --git a/docs/index.md b/docs/index.md index 32fc04f4a..efe298aad 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,7 @@ :hidden: user/introduction user/overview +user/reproducibility user/models tutorials/index ``` diff --git a/docs/user/overview.md b/docs/user/overview.md index 56fdf136d..9d58bd687 100644 --- a/docs/user/overview.md +++ b/docs/user/overview.md @@ -25,6 +25,13 @@ Efficiently tracking trajectory information is a core feature of simulation engi Learn more in [Understanding Reporting](../tutorials/reporting_tutorial.ipynb) +## Reproducibility + +MD trajectories can vary run-to-run because of stochastic integrators and +non-deterministic GPU operations. For guidance on deterministic settings in +PyTorch and integrator choices in TorchSim, see +[Reproducibility](reproducibility.md). + ## High-level vs Low-Level Under the hood, TorchSim takes a modular functional approach to atomistic simulation. Each integrator or optimizer has associated `init` and `update` functions that initialize and update a unique `State.` The state inherits from `SimState` and tracks the fixed and fluctuating parameters of the simulation, such as the `momenta` for NVT or the timestep for FIRE. The runner functions take this basic structure and wrap it in a convenient interface with autobatching and reporting. diff --git a/docs/user/reproducibility.md b/docs/user/reproducibility.md new file mode 100644 index 000000000..8b348e8f2 --- /dev/null +++ b/docs/user/reproducibility.md @@ -0,0 +1,95 @@ +# Reproducibility + +Molecular dynamics trajectories are often not exactly reproducible across runs, even when starting from the same initial structure and parameters. + +Two common sources are: + +- **Non-deterministic GPU operations**, where floating-point reductions may execute + in different orders +- **Stochastic integrators** such as Langevin, which add random forces + +For many MD tasks this is acceptable because sampling and ensemble statistics matter more than matching a step-by-step trajectory. If you need repeatable trajectories, use deterministic settings. + +## Global Deterministic setup in PyTorch + +Enable deterministic algorithms and seed random number generators: + +```python +import os +import random + +import numpy as np +import torch + +# Required by CUDA/cuBLAS for some deterministic GEMM paths. +# Set this before any CUDA operations. +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +seed = 42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +torch.use_deterministic_algorithms(True) +``` + +If deterministic mode raises a CuBLAS error, ensure `CUBLAS_WORKSPACE_CONFIG` is set before running your script. + +One of the main reasons for seeding the global states is to control the determinism of any `ModelInterface` being used to make predictions. It should be noted that models may implement their own PRNG solutions that do not draw from global seeds. As such, setting the global seeds doesn't necessarily ensure deterministic behavior. Please check each model being used on a case-by-case basis if determinism is critical to your workflow. + +## Seeding TorchSim SimStates + +In PyTorch, `torch.manual_seed()` only seeds the default global generator, i.e. the generator used in `torch.rand(..., generator=None)` if the generator isn't specified. An explicit `torch.Generator()` is independent and is initialized at the time of calling from the OS entropy source. Since TorchSim internally uses a `torch.Generator()` in the form of `SimState.rng` for all algorithmic randomness in core code setting the global seed is insufficient to ensure determinism within TorchSim. **You must explicitly seed your SimStates to ensure determinism**. + +```python +sim_state = ts.initialize_state(atoms, device, dtype) +sim_state.rng = 42 # required for reproducibility — torch.manual_seed() has no effect here +``` + +### Deterministic vs stochastic integrators in TorchSim + +- `ts.Integrator.nvt_langevin` and `ts.Integrator.npt_langevin` include stochastic + terms by design. When seeded via `state.rng`, they produce identical trajectories. + The `rng` generator controls **both** the initial momenta sampling **and** all per-step stochastic noise (Langevin OU noise, V-Rescale draws, C-Rescale barostat noise, etc.). It is stored on the state and automatically advances on every step, so running the same seed twice produces identical trajectories. +- `ts.Integrator.nvt_nose_hoover` and `ts.Integrator.nve` are deterministic at the + algorithmic level and require no seeding. + +For the simplest path to reproducibility, use a deterministic integrator such as Nosé-Hoover: + +```python +import torch_sim as ts + +state = ts.integrate( + system=atoms, + model=model, + n_steps=500, + timestep=0.001, + temperature=300, + integrator=ts.Integrator.nvt_nose_hoover, +) +``` + +In practice, exact reproducibility also depends on hardware, driver/library versions, and precision choices. + +### Batching and reproducibility + +Because TorchSim runs batched simulations, all systems in a batch share a single `torch.Generator`. Random numbers are drawn in a fixed order each step, so **identical batch composition** is required for exact reproducibility. Changing which systems are in a batch (or their order) will consume random numbers differently and cause trajectories to diverge. + +If strict reproducibility is required, keep your batching setup fixed. + +### Serialising the RNG state + +If you wish to be able to resume a session and ensure determinism you need to persist and reload the `torch.Generator` state. This can be done using `torch.save()` and `torch.Generator().set_state()`: + +```python +# save +rng_state = state.rng.get_state() +torch.save(rng_state, "rng_state.pt") + +# restore +gen = torch.Generator(device=state.device) +gen.set_state(torch.load("rng_state.pt")) +state.rng = gen +``` diff --git a/tests/test_constraints.py b/tests/test_constraints.py index b7072c0db..f2ffc2074 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -30,7 +30,6 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMode state=ar_supercell_sim_state, model=lj_model, kT=torch.tensor(10.0, dtype=DTYPE), - seed=42, ) ar_supercell_md_state.set_constrained_momenta( torch.randn_like(ar_supercell_md_state.momenta) * 0.1 @@ -70,7 +69,6 @@ def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMo state=ar_supercell_sim_state, model=lj_model, kT=torch.tensor(10.0, dtype=DTYPE), - seed=42, ) ar_supercell_md_state.set_constrained_momenta( torch.randn_like(ar_supercell_md_state.momenta) * 0.1 @@ -94,7 +92,7 @@ def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesM cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - 3 ) - state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT) positions = [] system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_( 0, @@ -138,7 +136,7 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone assert torch.allclose( cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - torch.tensor([6]) ) - state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT) positions = [] temperatures = [] for _step in range(n_steps): @@ -582,7 +580,7 @@ def test_integrators_with_constraints( # Run integration if integrator == "nve": - state = ts.nve_init(cu_sim_state, lj_model, kT=kT, seed=42) + state = ts.nve_init(cu_sim_state, lj_model, kT=kT) for _ in range(n_steps): state = ts.nve_step(state, lj_model, dt=dt) elif integrator == "nvt_nose_hoover": @@ -590,7 +588,7 @@ def test_integrators_with_constraints( for _ in range(n_steps): state = ts.nvt_nose_hoover_step(state, lj_model, dt=dt, kT=kT) elif integrator == "npt_langevin": - state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42, dt=dt) + state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, dt=dt) for _ in range(n_steps): state = ts.npt_langevin_step( state, @@ -644,7 +642,6 @@ def test_multiple_constraints_and_dof( cu_sim_state, lj_model, kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, - seed=42, ) for _ in range(200): state = ts.nvt_langevin_step( @@ -757,7 +754,7 @@ def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: initial = get_centers_of_mass( state.positions, state.masses, state.system_idx, state.n_systems ) - md_state = ts.nve_init(state, lj_model, kT=torch.tensor(100.0, dtype=DTYPE), seed=42) + md_state = ts.nve_init(state, lj_model, kT=torch.tensor(100.0, dtype=DTYPE)) for _ in range(100): md_state = ts.nve_step(md_state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) final = get_centers_of_mass( @@ -815,7 +812,6 @@ def test_temperature_with_constrained_dof( cu_sim_state, lj_model, kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, - seed=42, ) temps = [] for _ in range(4000): diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 6ddafa064..f99208e8f 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -3,14 +3,15 @@ import torch_sim as ts from tests.conftest import DEVICE, DTYPE -from torch_sim.integrators import calculate_momenta +from torch_sim.integrators import initialize_momenta from torch_sim.integrators.npt import _compute_cell_force from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.state import coerce_prng from torch_sim.units import MetalUnits -def test_calculate_momenta_basic(): - """Test basic functionality of calculate_momenta.""" +def test_initialize_momenta_basic(): + """Test basic functionality of initialize_momenta.""" seed = 42 # Create test inputs for 3 systems with 2 atoms each @@ -23,7 +24,8 @@ def test_calculate_momenta_basic(): kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=DTYPE, device=DEVICE) # Run the function - momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) + gen = coerce_prng(seed, device=DEVICE) + momenta = initialize_momenta(positions, masses, system_idx, kT, generator=gen) # Basic checks assert momenta.shape == positions.shape @@ -40,8 +42,8 @@ def test_calculate_momenta_basic(): ) -def test_calculate_momenta_single_atoms(): - """Test that calculate_momenta preserves momentum for systems with single atoms.""" +def test_initialize_momenta_single_atoms(): + """Test that initialize_momenta preserves momentum for systems with single atoms.""" seed = 42 # Create test inputs with some systems having single atoms @@ -59,7 +61,8 @@ def test_calculate_momenta_single_atoms(): ) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1) # Run the function - momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) + gen = coerce_prng(seed, device=DEVICE) + momenta = initialize_momenta(positions, masses, system_idx, kT, generator=gen) # Check that single-atom systems have unchanged momenta for sys_idx in (0, 2, 3): # Single atom systems @@ -82,13 +85,14 @@ def test_npt_langevin( ) -> None: n_steps = 200 dt = torch.tensor(0.001, dtype=DTYPE) - kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature - external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(10.0, dtype=DTYPE) * MetalUnits.pressure alpha = 40 * dt cell_alpha = alpha b_tau = 1 / (1000 * dt) # Initialize integrator using new direct API + ar_double_sim_state.rng = 42 state = ts.npt_langevin_init( state=ar_double_sim_state, model=lj_model, @@ -97,7 +101,6 @@ def test_npt_langevin( alpha=alpha, cell_alpha=cell_alpha, b_tau=b_tau, - seed=42, ) # Run dynamics for several steps @@ -164,6 +167,7 @@ def test_npt_langevin_multi_kt( b_tau = 1 / (1000 * dt) # Initialize integrator using new direct API + ar_double_sim_state.rng = 42 state = ts.npt_langevin_init( state=ar_double_sim_state, model=lj_model, @@ -172,7 +176,6 @@ def test_npt_langevin_multi_kt( alpha=alpha, cell_alpha=cell_alpha, b_tau=b_tau, - seed=42, ) # Run dynamics for several steps @@ -216,9 +219,8 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nvt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + ar_double_sim_state.rng = 42 + state = ts.nvt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT) energies = [] temperatures = [] for _step in range(n_steps): @@ -272,9 +274,8 @@ def test_nvt_langevin_multi_kt( kT = torch.tensor([300, 10_000], dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nvt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + ar_double_sim_state.rng = 42 + state = ts.nvt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT) energies = [] temperatures = [] for _step in range(n_steps): @@ -310,8 +311,9 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature # Run dynamics for several steps + ar_double_sim_state.rng = 42 state = ts.nvt_nose_hoover_init( - state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, seed=42 + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT ) energies = [] temperatures = [] @@ -394,9 +396,9 @@ def test_nvt_nose_hoover_multi_equivalent_to_single( initial_momenta = [] # Run dynamics for several steps for i in range(mixed_double_sim_state.n_systems): - state = ts.nvt_nose_hoover_init( - state=mixed_double_sim_state[i], model=lj_model, dt=dt, kT=kT, seed=42 - ) + sub_state = mixed_double_sim_state[i] + sub_state.rng = 42 + state = ts.nvt_nose_hoover_init(state=sub_state, model=lj_model, dt=dt, kT=kT) initial_momenta.append(state.momenta.clone()) for _step in range(n_steps): state = ts.nvt_nose_hoover_step( @@ -414,12 +416,12 @@ def test_nvt_nose_hoover_multi_equivalent_to_single( initial_momenta_tensor = torch.concat(initial_momenta) final_temperatures = torch.concat(final_temperatures) + mixed_double_sim_state.rng = 42 state = ts.nvt_nose_hoover_init( state=mixed_double_sim_state, model=lj_model, dt=dt, kT=kT, - seed=42, momenta=initial_momenta_tensor, ) for _step in range(n_steps): @@ -442,8 +444,9 @@ def test_nvt_nose_hoover_multi_kt( kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature # Run dynamics for several steps + ar_double_sim_state.rng = 42 state = ts.nvt_nose_hoover_init( - state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, seed=42 + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT ) energies = [] temperatures = [] @@ -491,9 +494,8 @@ def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nvt_vrescale_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + ar_double_sim_state.rng = 42 + state = ts.nvt_vrescale_init(state=ar_double_sim_state, model=lj_model, kT=kT) energies = [] temperatures = [] for _step in range(n_steps): @@ -544,12 +546,13 @@ def test_npt_anisotropic_crescale( ) -> None: n_steps = 200 dt = torch.tensor(0.001, dtype=DTYPE) - kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature - external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(10.0, dtype=DTYPE) * MetalUnits.pressure tau_p = torch.tensor(0.1, dtype=DTYPE) isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) # Initialize integrator using new direct API + ar_double_sim_state.rng = 42 state = ts.npt_crescale_init( state=ar_double_sim_state, model=lj_model, @@ -557,7 +560,6 @@ def test_npt_anisotropic_crescale( kT=kT, tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, - seed=42, ) # Run dynamics for several steps @@ -617,12 +619,13 @@ def test_npt_isotropic_crescale( ) -> None: n_steps = 200 dt = torch.tensor(0.001, dtype=DTYPE) - kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature - external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(10.0, dtype=DTYPE) * MetalUnits.pressure tau_p = torch.tensor(0.1, dtype=DTYPE) isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) # Initialize integrator using new direct API + ar_double_sim_state.rng = 42 state = ts.npt_crescale_init( state=ar_double_sim_state, model=lj_model, @@ -630,7 +633,6 @@ def test_npt_isotropic_crescale( kT=kT, tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, - seed=42, ) # Run dynamics for several steps @@ -693,13 +695,13 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure # Run dynamics for several steps + ar_double_sim_state.rng = 42 state = ts.npt_nose_hoover_init( state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, - seed=42, ) energies = [] temperatures = [] @@ -784,13 +786,14 @@ def test_npt_nose_hoover_multi_equivalent_to_single( initial_momenta = [] # Run dynamics for several steps for i in range(mixed_double_sim_state.n_systems): + sub_state = mixed_double_sim_state[i] + sub_state.rng = 42 state = ts.npt_nose_hoover_init( - state=mixed_double_sim_state[i], + state=sub_state, model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, - seed=42, ) initial_momenta.append(state.momenta.clone()) for _step in range(n_steps): @@ -810,13 +813,13 @@ def test_npt_nose_hoover_multi_equivalent_to_single( initial_momenta_tensor = torch.concat(initial_momenta) final_temperatures = torch.concat(final_temperatures) + mixed_double_sim_state.rng = 42 state = ts.npt_nose_hoover_init( state=mixed_double_sim_state, model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, - seed=42, momenta=initial_momenta_tensor, ) for _step in range(n_steps): @@ -846,13 +849,13 @@ def test_npt_nose_hoover_multi_kt( external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure # Run dynamics for several steps + ar_double_sim_state.rng = 42 state = ts.npt_nose_hoover_init( state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, - seed=42, ) energies = [] temperatures = [] @@ -903,10 +906,11 @@ def test_npt_nose_hoover_multi_kt( def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): n_steps = 100 dt = torch.tensor(0.001, dtype=DTYPE) - kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nve_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) + ar_double_sim_state.rng = 42 + state = ts.nve_init(state=ar_double_sim_state, model=lj_model, kT=kT) # Run dynamics for several steps energies = [] @@ -951,7 +955,7 @@ def test_compare_single_vs_batched_integrators( # Initialize momenta (even if zero) and get forces state = ts.nve_init( - state=state, model=lj_model, kT=kT, seed=42 + state=state, model=lj_model, kT=kT ) # kT is ignored if momenta are set below # Ensure momenta start at zero AFTER init which might randomize them based on kT state.momenta = torch.zeros_like(state.momenta) # Start from rest @@ -1012,3 +1016,84 @@ def test_compute_cell_force_atoms_per_system(): # Force ratio should match atom ratio (8:1) with the fix assert abs(force_ratio - 8.0) / 8.0 < 0.1 + + +# --------------------------------------------------------------------------- +# Reproducibility tests +# --------------------------------------------------------------------------- + + +def test_nvt_langevin_reproducibility( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Two runs with the same prng seed must produce identical trajectories.""" + n_steps = 10 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + def _run(seed: int) -> tuple[torch.Tensor, torch.Tensor]: + ar_double_sim_state.rng = seed + state = ts.nvt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT) + for _ in range(n_steps): + state = ts.nvt_langevin_step(state=state, model=lj_model, dt=dt, kT=kT) + return state.positions.clone(), state.momenta.clone() + + pos_a, mom_a = _run(123) + pos_b, mom_b = _run(123) + + torch.testing.assert_close(pos_a, pos_b) + torch.testing.assert_close(mom_a, mom_b) + + # Different seeds should diverge + pos_c, mom_c = _run(456) + assert not torch.allclose(pos_a, pos_c) + assert not torch.allclose(mom_a, mom_c) + + +def test_npt_langevin_reproducibility( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Two runs with the same seed must produce identical NPT Langevin trajectories.""" + n_steps = 20 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(10, dtype=DTYPE) * MetalUnits.pressure + alpha = 40 * dt + cell_alpha = alpha + b_tau = dt # make this very small to ensure the barostat is active + + def _run(seed: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ar_double_sim_state.rng = seed + # NOTE: this init function clones the state so we can use the same fixture + # for all the runs without concern. + state = ts.npt_langevin_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + alpha=alpha, + cell_alpha=cell_alpha, + b_tau=b_tau, + ) + for _ in range(n_steps): + state = ts.npt_langevin_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + return state.positions.clone(), state.momenta.clone(), state.cell.clone() + + pos_a, mom_a, cell_a = _run(123) + pos_b, mom_b, cell_b = _run(123) + + torch.testing.assert_close(pos_a, pos_b) + torch.testing.assert_close(mom_a, mom_b) + torch.testing.assert_close(cell_a, cell_b) + + # Different seeds should diverge + pos_c, mom_c, cell_c = _run(456) + assert not torch.allclose(pos_a, pos_c) + assert not torch.allclose(mom_a, mom_c) + assert not torch.allclose(cell_a, cell_c) diff --git a/tests/test_runners.py b/tests/test_runners.py index b32de4f4b..8a817c1f6 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -100,7 +100,7 @@ def test_integrate_double_nvt( n_steps=10, temperature=100.0, # K timestep=0.001, # ps - init_kwargs=dict(seed=481516), + init_kwargs={}, ) assert isinstance(final_state, SimState) @@ -120,7 +120,7 @@ def test_integrate_double_nvt_multiple_temperatures( n_steps=n_steps, temperature=[100.0, 200.0], # K timestep=0.001, # ps - init_kwargs=dict(seed=481516), + init_kwargs={}, ) batcher = ts.autobatching.BinningAutoBatcher( @@ -136,7 +136,7 @@ def test_integrate_double_nvt_multiple_temperatures( temperature=[100.0, 200.0], # K timestep=0.001, # ps autobatcher=batcher, - init_kwargs=dict(seed=481516), + init_kwargs={}, ) # Temperature tensor with correct shape (n_steps, n_systems) @@ -148,7 +148,7 @@ def test_integrate_double_nvt_multiple_temperatures( temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1), timestep=0.001, # ps autobatcher=batcher, - init_kwargs=dict(seed=481516), + init_kwargs={}, ) # Temperature tensor with incorrect shape (n_systems, n_steps) @@ -161,7 +161,7 @@ def test_integrate_double_nvt_multiple_temperatures( temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1).T, # K timestep=0.001, # ps autobatcher=batcher, - init_kwargs=dict(seed=481516), + init_kwargs={}, ) diff --git a/tests/test_state.py b/tests/test_state.py index 1ed638805..a77c370c5 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,6 +13,7 @@ _normalize_system_indices, _pop_states, _slice_state, + coerce_prng, get_attrs_for_scope, ) @@ -30,7 +31,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) assert set(per_system_attrs) == {"cell", "charge", "spin"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs) == {"pbc"} + assert set(global_attrs) == {"pbc", "_rng"} def test_all_attributes_must_be_specified_in_scopes() -> None: @@ -772,3 +773,201 @@ def test_wrap_positions_batched(si_double_sim_state: SimState) -> None: state.positions[mask] = state.positions[mask] + lattice_shift wrapped = state.wrap_positions assert torch.allclose(wrapped, original_positions, atol=1e-5) + + +# ── rng property tests ────────────────────────────────────────────────────── + + +def test_rng_lazy_init(si_sim_state: SimState) -> None: + """rng property creates a Generator on first access when _rng is None.""" + state = si_sim_state.clone() + state.rng = None + assert state._rng is None # noqa: SLF001 + gen = state.rng + assert isinstance(gen, torch.Generator) + assert gen.device == state.device + + +def test_rng_int_seed_via_constructor() -> None: + """Passing an int _rng to SimState is lazily coerced on first .rng access.""" + state = SimState( + positions=torch.randn(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(2, dtype=torch.int), + _rng=42, + ) + assert isinstance(state.rng, torch.Generator) + + +def test_rng_int_seed_via_property(si_sim_state: SimState) -> None: + """Setting rng to an int via the property coerces on next access.""" + state = si_sim_state.clone() + state.rng = 123 + gen = state.rng + assert isinstance(gen, torch.Generator) + + +def test_rng_generator_passthrough(si_sim_state: SimState) -> None: + """Setting rng to a Generator stores it directly.""" + state = si_sim_state.clone() + gen = torch.Generator(device=state.device) + gen.manual_seed(7) + state.rng = gen + assert state.rng is gen + + +def test_rng_none_resets(si_sim_state: SimState) -> None: + """Setting rng = None resets; next access lazily re-initialises.""" + state = si_sim_state.clone() + state.rng = 42 + first_gen = state.rng + state.rng = None + second_gen = state.rng + assert isinstance(second_gen, torch.Generator) + assert second_gen is not first_gen + + +def test_rng_int_seed_reproducible() -> None: + """Same int seed produces the same random sequence.""" + + def _make_state(seed: int) -> SimState: + return SimState( + positions=torch.randn(4, 3), + masses=torch.ones(4), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(4, dtype=torch.int), + _rng=seed, + ) + + s1 = _make_state(99) + s2 = _make_state(99) + r1 = torch.randn(5, generator=s1.rng) + r2 = torch.randn(5, generator=s2.rng) + assert torch.equal(r1, r2) + + +def test_rng_clone_independent(si_sim_state: SimState) -> None: + """Cloned state has an independent Generator with the same state.""" + state = si_sim_state.clone() + state.rng = 42 + clone = state.clone() + + assert torch.equal(state.rng.get_state(), clone.rng.get_state()) + assert clone.rng is not state.rng + + # Drawing from one does not affect the other + torch.randn(3, generator=state.rng) + assert not torch.equal(state.rng.get_state(), clone.rng.get_state()) + + +def test_rng_clone_none_preserved(si_sim_state: SimState) -> None: + """Cloning a state with _rng=None keeps it None on the clone.""" + state = si_sim_state.clone() + state.rng = None + clone = state.clone() + assert clone._rng is None # noqa: SLF001 + + +def test_rng_from_state_mdstate(si_sim_state: SimState) -> None: + """MDState.from_state copies the rng from the source SimState.""" + state = si_sim_state.clone() + state.rng = 42 + original_rng_state = state.rng.get_state().clone() + + md = MDState.from_state( + state, + momenta=torch.zeros_like(state.positions), + energy=torch.zeros(state.n_systems, device=state.device), + forces=torch.zeros_like(state.positions), + ) + assert isinstance(md.rng, torch.Generator) + assert torch.equal(md.rng.get_state(), original_rng_state) + assert md.rng is not state.rng + + +def test_rng_mdstate_inherits_lazy_init() -> None: + """MDState without explicit _rng still lazily initialises via the property.""" + md = MDState( + positions=torch.randn(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(2, dtype=torch.int), + momenta=torch.zeros(2, 3), + energy=torch.zeros(1), + forces=torch.zeros(2, 3), + ) + assert md._rng is None # noqa: SLF001 + gen = md.rng + assert isinstance(gen, torch.Generator) + + +def test_rng_concat_takes_first(si_sim_state: SimState) -> None: + """concatenate_states takes rng from the first state.""" + s1 = si_sim_state.clone() + s2 = si_sim_state.clone() + s1.rng = 11 + s2.rng = 22 + combined = ts.concatenate_states([s1, s2]) + assert torch.equal(combined.rng.get_state(), s1.rng.get_state()) + + +def test_rng_split_preserves(si_sim_state: SimState) -> None: + """Splitting a batched state shares the same rng value to each piece.""" + batched = ts.concatenate_states([si_sim_state, si_sim_state]) + batched.rng = 77 + parts = batched.split() + assert len(parts) == 2 + for part in parts: + assert isinstance(part.rng, torch.Generator) + + +def test_coerce_prng_none() -> None: + """None seed creates an unseeded Generator.""" + gen = coerce_prng(None, device=DEVICE) + assert isinstance(gen, torch.Generator) + + +def test_coerce_prng_int_seed() -> None: + """Int seed creates a deterministically-seeded Generator.""" + g1 = coerce_prng(42, device=DEVICE) + g2 = coerce_prng(42, device=DEVICE) + r1 = torch.randn(5, generator=g1) + r2 = torch.randn(5, generator=g2) + assert torch.equal(r1, r2) + + +def test_coerce_prng_different_seeds_diverge() -> None: + """Different int seeds produce different random streams.""" + g1 = coerce_prng(1, device=DEVICE) + g2 = coerce_prng(2, device=DEVICE) + r1 = torch.randn(5, generator=g1) + r2 = torch.randn(5, generator=g2) + assert not torch.equal(r1, r2) + + +def test_coerce_prng_generator_passthrough() -> None: + """Passing a Generator returns the exact same object.""" + gen = torch.Generator() + gen.manual_seed(7) + result = coerce_prng(gen, device=DEVICE) + assert result is gen + + +def test_coerce_prng_default_no_arg() -> None: + """Calling with no argument (default None) returns a Generator.""" + gen = coerce_prng(None, device=DEVICE) + assert isinstance(gen, torch.Generator) + + +def test_rng_setter_int_advances_state(si_sim_state: SimState) -> None: + """Setting rng to an int must store a Generator so its state advances.""" + state = si_sim_state.clone() + state.rng = 99 + # Two consecutive draws should differ because the Generator state advances + r1 = torch.randn(5, generator=state.rng) + r2 = torch.randn(5, generator=state.rng) + assert not torch.equal(r1, r2) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index fe9e8c3f2..d1d148172 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1311,9 +1311,7 @@ def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJon kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Same cell - state = ts.nvt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + state = ts.nvt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT) state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) positions = [state.positions.detach().clone()] for _step in range(n_steps): @@ -1335,9 +1333,7 @@ def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJon assert torch.allclose(unwrapped_positions, positions, atol=1e-4) # Different cell - state = ts.npt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42, dt=dt - ) + state = ts.npt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT, dt=dt) state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) positions = [state.positions.detach().clone()] cells = [state.cell.detach().clone()] diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index 3405b732e..d3a32f54d 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -68,13 +68,7 @@ import torch_sim as ts -from .md import ( - MDState, - calculate_momenta, - momentum_step, - position_step, - velocity_verlet_step, -) +from .md import MDState, initialize_momenta, momentum_step, position_step, velocity_verlet from .npt import ( NPTLangevinState, NPTNoseHooverState, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 7078e2799..c7fdfb920 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -32,6 +32,10 @@ class MDState(SimState): momenta (torch.Tensor): Particle momenta [n_particles, n_dim] energy (torch.Tensor): Potential energy of the system [n_systems] forces (torch.Tensor): Forces on particles [n_particles, n_dim] + rng (torch.Generator): RNG used by stochastic integrators (lazily + initialised via the ``rng`` property on ``SimState``). Stored on + the state so that the random stream advances consistently across + steps and can be serialised for reproducibility. Properties: velocities (torch.Tensor): Particle velocities [n_particles, n_dim] @@ -91,12 +95,12 @@ def calc_kT(self) -> torch.Tensor: # noqa: N802 ) -def calculate_momenta( +def initialize_momenta( positions: torch.Tensor, masses: torch.Tensor, system_idx: torch.Tensor, kT: float | torch.Tensor, - seed: int | None = None, + generator: torch.Generator | None = None, ) -> torch.Tensor: """Initialize particle momenta based on temperature. @@ -109,7 +113,7 @@ def calculate_momenta( masses (torch.Tensor): Particle masses [n_particles] system_idx (torch.Tensor): System indices [n_particles] kT (torch.Tensor): Temperature in energy units [n_systems] - seed (int, optional): Random seed for reproducibility. Defaults to None. + generator: Optional ``torch.Generator`` for reproducibility. Returns: torch.Tensor: Initialized momenta [n_particles, n_dim] @@ -117,13 +121,8 @@ def calculate_momenta( device = positions.device dtype = positions.dtype - generator = torch.Generator(device=device) - if seed is not None: - generator.manual_seed(seed) - kT = torch.as_tensor(kT, device=device, dtype=dtype) if kT.ndim > 0: - # kT is a tensor with shape (n_systems,) kT = kT[system_idx] # Generate random momenta from normal distribution diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index a7d57313d..7efa64982 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -13,8 +13,8 @@ MDState, NoseHooverChain, NoseHooverChainFns, - calculate_momenta, construct_nose_hoover_chain, + initialize_momenta, momentum_step, ) from torch_sim.integrators.nvt import _vrescale_update @@ -106,7 +106,10 @@ def _npt_langevin_beta( torch.Tensor: Random noise term for force calculation [n_particles, n_dim] """ # Generate system-specific noise with correct shape - noise = torch.randn_like(state.momenta) + rng = state.rng + noise = torch.randn( + state.momenta.shape, device=state.device, dtype=state.dtype, generator=rng + ) # Calculate the thermal noise amplitude by system batch_kT = kT @@ -149,7 +152,10 @@ def _npt_langevin_cell_beta( [n_systems, n_dimensions, n_dimensions] """ # Generate standard normal distribution (zero mean, unit variance) - noise = torch.randn_like(state.cell_positions, device=state.device, dtype=state.dtype) + rng = state.rng + noise = torch.randn( + state.cell_positions.shape, device=state.device, dtype=state.dtype, generator=rng + ) if kT.ndim == 0: kT = kT.expand(state.n_systems) @@ -275,7 +281,13 @@ def _npt_langevin_cell_velocity_step( c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) # Generate system-specific cell noise with correct shape (n_systems, 3, 3) - cell_noise = torch.randn_like(state.cell_velocities) + rng = state.rng + cell_noise = torch.randn( + state.cell_velocities.shape, + device=state.cell_velocities.device, + dtype=state.cell_velocities.dtype, + generator=rng, + ) # Calculate thermal noise amplitude noise_prefactor = torch.sqrt( @@ -350,7 +362,10 @@ def _npt_langevin_position_step( c_2 = (2 * L_n_new_atoms / (L_n_new_atoms + L_n_atoms)) * b * dt_atoms # Generate atom-specific noise - noise = torch.randn_like(state.momenta) + rng = state.rng + noise = torch.randn( + state.momenta.shape, device=state.device, dtype=state.dtype, generator=rng + ) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -415,7 +430,10 @@ def _npt_langevin_velocity_step( c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2.unsqueeze(-1) # Generate atom-specific noise - noise = torch.randn_like(state.momenta) + rng = state.rng + noise = torch.randn( + state.momenta.shape, device=state.device, dtype=state.dtype, generator=rng + ) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -508,7 +526,6 @@ def npt_langevin_init( alpha: float | torch.Tensor | None = None, cell_alpha: float | torch.Tensor | None = None, b_tau: float | torch.Tensor | None = None, - seed: int | None = None, **_kwargs: Any, ) -> NPTLangevinState: """Initialize an NPT Langevin state from input data. @@ -518,6 +535,8 @@ def npt_langevin_init( cell parameters, and barostat variables. It computes initial forces and stress using the provided model. + To seed the RNG set ``state.rng = seed`` before calling. + Args: model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. @@ -533,7 +552,6 @@ def npt_langevin_init( b_tau (torch.Tensor, optional): Barostat time constant controlling how quickly the system responds to pressure differences, either scalar or shape [n_systems]. Defaults to 1/(1000*dt). - seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: NPTLangevinState: Initialized state for NPT Langevin integration containing @@ -576,7 +594,9 @@ def npt_langevin_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) # Initialize cell parameters @@ -1274,7 +1294,6 @@ def npt_nose_hoover_init( sy_steps: int = 3, t_tau: float | torch.Tensor | None = None, b_tau: float | torch.Tensor | None = None, - seed: int | None = None, **kwargs: Any, ) -> NPTNoseHooverState: """Initialize the NPT Nose-Hoover state. @@ -1284,6 +1303,8 @@ def npt_nose_hoover_init( system with appropriate initial conditions including particle positions, momenta, cell variables, and thermostat chains. + To seed the RNG set ``state.rng = seed`` before calling. + Args: model (ModelInterface): Model to compute forces and energies state: Initial system state as MDState or dict containing positions, masses, @@ -1298,7 +1319,6 @@ def npt_nose_hoover_init( equilibrates. Defaults to 100*dt b_tau: Barostat relaxation time. Controls how quickly pressure equilibrates. Defaults to 1000*dt - seed: Random seed for momenta initialization. Used for reproducible runs **kwargs: Additional state variables like atomic_numbers or pre-initialized momenta @@ -1368,7 +1388,9 @@ def npt_nose_hoover_init( # Initialize momenta momenta = kwargs.get( "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) # Compute total DOF for thermostat initialization and a zero KE placeholder @@ -1700,7 +1722,12 @@ def _crescale_anisotropic_barostat_step( prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) change_sqrt_vol = -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=state.rng, + ) new_sqrt_volume = sqrt_vol + change_sqrt_vol ## Step 2: compute deformation matrix prefactor_random_matrix = ( @@ -1721,6 +1748,7 @@ def _crescale_anisotropic_barostat_step( 3, device=state.positions.device, dtype=state.positions.dtype, + generator=state.rng, ) random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ :, None, None @@ -1735,7 +1763,12 @@ def _crescale_anisotropic_barostat_step( ## Step 3: propagate sqrt(volume) for dt/2 new_sqrt_volume += -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=state.rng, + ) rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( -1, 1, 1 ) @@ -1772,9 +1805,15 @@ def _crescale_independent_lengths_barostat_step( kT * state.isothermal_compressibility * dt / (4 * state.tau_p) ) prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + rng = state.rng change_sqrt_vol = -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=rng, + ) new_sqrt_volume = sqrt_vol + change_sqrt_vol ## Step 2: compute deformation matrix prefactor_random_matrix = ( @@ -1792,6 +1831,7 @@ def _crescale_independent_lengths_barostat_step( 3, device=state.positions.device, dtype=state.positions.dtype, + generator=rng, ) random_matrix_tilde = random_matrix - torch.mean(random_matrix, dim=1, keepdim=True) deformation_matrix = torch.exp( @@ -1801,7 +1841,12 @@ def _crescale_independent_lengths_barostat_step( ## Step 3: propagate sqrt(volume) for dt/2 new_sqrt_volume += -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=rng, + ) rscaling = deformation_matrix * torch.pow( (new_sqrt_volume / sqrt_vol), 2 / 3 ).unsqueeze(-1) @@ -1865,9 +1910,15 @@ def _crescale_average_anisotropic_barostat_step( kT * state.isothermal_compressibility * dt / (4 * state.tau_p) ) prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + rng = state.rng change_sqrt_vol = -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=rng, + ) new_sqrt_volume = sqrt_vol + change_sqrt_vol ## Step 2: compute deformation matrix prefactor_random_matrix = ( @@ -1888,6 +1939,7 @@ def _crescale_average_anisotropic_barostat_step( 3, device=state.positions.device, dtype=state.positions.dtype, + generator=rng, ) random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ :, None, None @@ -1902,7 +1954,12 @@ def _crescale_average_anisotropic_barostat_step( ## Step 3: propagate sqrt(volume) for dt/2 new_sqrt_volume += -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt / 2 + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=rng, + ) rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( -1, 1, 1 ) @@ -1943,9 +2000,15 @@ def _crescale_isotropic_barostat_step( kT * state.isothermal_compressibility * dt / (4 * state.tau_p) ) prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + rng = state.rng change_sqrt_vol = -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt + prefactor_random * torch.randn_like(sqrt_vol) + ) * dt + prefactor_random * torch.randn( + sqrt_vol.shape, + device=sqrt_vol.device, + dtype=sqrt_vol.dtype, + generator=rng, + ) new_sqrt_volume = sqrt_vol + change_sqrt_vol # Update positions and momenta (barostat + half momentum step) @@ -2272,7 +2335,6 @@ def npt_crescale_init( dt: float | torch.Tensor, tau_p: float | torch.Tensor | None = None, isothermal_compressibility: float | torch.Tensor | None = None, - seed: int | None = None, ) -> NPTCRescaleState: """Initialize the NPT cell rescaling state. @@ -2283,6 +2345,8 @@ def npt_crescale_init( Only allow isotropic external stress, but can run both isotropic and anisotropic cell rescaling. + To seed the RNG set ``state.rng = seed`` before calling. + Args: state: Initial system state as MDState or dict containing positions, masses, cell, and PBC information @@ -2291,7 +2355,6 @@ def npt_crescale_init( dt: Integration timestep tau_p: Barostat relaxation time. Controls how quickly pressure equilibrates. isothermal_compressibility: Isothermal compressibility of the system. - seed: Random seed for momenta initialization. """ device, dtype = model.device, model.dtype @@ -2324,7 +2387,9 @@ def npt_crescale_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) # Create the initial state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index f49de51d4..fed46c207 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -6,7 +6,7 @@ from torch_sim.integrators.md import ( MDState, - calculate_momenta, + initialize_momenta, momentum_step, position_step, ) @@ -20,7 +20,6 @@ def nve_init( model: ModelInterface, *, kT: float | torch.Tensor, - seed: int | None = None, **_kwargs: Any, ) -> MDState: """Initialize an NVE state from input data. @@ -29,6 +28,8 @@ def nve_init( energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution at the specified temperature. + To seed the RNG set ``state.rng = seed`` before calling. + Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. @@ -36,7 +37,6 @@ def nve_init( masses, cell, pbc, and other required state variables kT: Temperature in energy units for initializing momenta, scalar or with shape [n_systems] - seed: Random seed for reproducibility Returns: MDState: Initialized state for NVE integration containing positions, @@ -54,7 +54,9 @@ def nve_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) return MDState.from_state( diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index c35ea5a2b..5840763c7 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -11,8 +11,8 @@ MDState, NoseHooverChain, NoseHooverChainFns, - calculate_momenta, construct_nose_hoover_chain, + initialize_momenta, momentum_step, position_step, velocity_verlet_step, @@ -68,7 +68,13 @@ def _ou_step( c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1) # Generate random noise from normal distribution - noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) + rng = state.rng + noise = torch.randn( + state.momenta.shape, + device=state.device, + dtype=state.dtype, + generator=rng, + ) new_momenta = ( c1.unsqueeze(-1) * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise @@ -82,7 +88,6 @@ def nvt_langevin_init( model: ModelInterface, *, kT: float | torch.Tensor, - seed: int | None = None, **_kwargs: Any, ) -> MDState: """Initialize an NVT state from input data for Langevin dynamics. @@ -91,6 +96,8 @@ def nvt_langevin_init( energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution at the specified temperature. + To seed the RNG set ``state.rng = seed`` before calling. + Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. @@ -98,7 +105,6 @@ def nvt_langevin_init( masses, cell, pbc, and other required state vars kT: Temperature in energy units for initializing momenta, either scalar or with shape [n_systems] - seed: Random seed for reproducibility Returns: MDState: Initialized state for NVT integration containing positions, @@ -117,7 +123,9 @@ def nvt_langevin_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) return MDState.from_state( state, @@ -247,7 +255,6 @@ def nvt_nose_hoover_init( chain_length: int = 3, chain_steps: int = 3, sy_steps: int = 3, - seed: int | None = None, **kwargs: Any, ) -> NVTNoseHooverState: """Initialize the NVT Nose-Hoover state. @@ -257,6 +264,8 @@ def nvt_nose_hoover_init( coupling the system to a chain of thermostats. The integration scheme is time-reversible and conserves an extended energy quantity. + To seed the RNG set ``state.rng = seed`` before calling. + Args: state: Initial system state as SimState or dict model: Neural network model that computes energies and forces @@ -266,7 +275,6 @@ def nvt_nose_hoover_init( chain_length: Number of thermostats in Nose-Hoover chain (default: 3) chain_steps: Number of chain integration substeps (default: 3) sy_steps: Number of Suzuki-Yoshida steps - must be 1, 3, 5, or 7 (default: 3) - seed: Random seed for momenta initialization **kwargs: Additional state variables Returns: @@ -294,7 +302,9 @@ def nvt_nose_hoover_init( model_output = model(state) momenta = kwargs.get( "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) # Calculate initial kinetic energy per system @@ -535,9 +545,10 @@ def _vrescale_update( KE_new = dof * kT_tensor / 2 # Generate random numbers - r1 = torch.randn(n_systems, device=device, dtype=dtype) - # Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1) - r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample() + rng = state.rng + r1 = torch.randn(n_systems, device=device, dtype=dtype, generator=rng) + # Sample Gamma((dof - 1)/2, 1/2) via _standard_gamma so we can seed it + r2 = torch._standard_gamma((dof - 1) / 2, generator=rng) * 2 # noqa: SLF001 # Calculate scaling coefficients c1 = torch.exp(-dt_tensor / tau_tensor) @@ -557,7 +568,6 @@ def nvt_vrescale_init( model: ModelInterface, *, kT: float | torch.Tensor, - seed: int | None = None, **_kwargs: Any, ) -> NVTVRescaleState: """Initialize an NVT state from input data for velocity rescaling dynamics. @@ -567,6 +577,8 @@ def nvt_vrescale_init( samples from the canonical ensemble by rescaling velocities with an appropriately chosen random factor. + To seed the RNG set ``state.rng = seed`` before calling. + Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. @@ -574,7 +586,6 @@ def nvt_vrescale_init( masses, cell, pbc, and other required state vars kT: Temperature in energy units for initializing momenta, either scalar or with shape [n_systems] - seed: Random seed for reproducibility Returns: MDState: Initialized state for NVT integration containing positions, @@ -593,7 +604,9 @@ def nvt_vrescale_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + initialize_momenta( + state.positions, state.masses, state.system_idx, kT, state.rng + ), ) return NVTVRescaleState.from_state( @@ -629,7 +642,6 @@ def nvt_vrescale_step( with shape [n_systems] tau: Thermostat relaxation time controlling the coupling strength, either scalar or with shape [n_systems]. Defaults to 100*dt. - seed: Random seed for reproducibility Returns: MDState: Updated state after one complete V-Rescale step with new positions, diff --git a/torch_sim/state.py b/torch_sim/state.py index bd0dd1f2a..cd06255da 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -15,7 +15,7 @@ import torch import torch_sim as ts -from torch_sim.typing import StateLike +from torch_sim.typing import PRNGLike, StateLike if TYPE_CHECKING: @@ -26,6 +26,35 @@ from torch_sim.constraints import Constraint, merge_constraints, validate_constraints +def coerce_prng(rng: PRNGLike, device: torch.device) -> torch.Generator: + """Coerce an int seed or existing Generator into a ``torch.Generator``. + + Args: + rng: An integer seed, an existing ``torch.Generator``, or ``None`` + (uses a non-deterministic seed). + device: Device for the generator. + + Returns: + A ``torch.Generator`` ready for use. + """ + if isinstance(rng, torch.Generator): + if rng.device == device: + return rng + new = torch.Generator(device=device) + new.set_state(rng.get_state()) + return new + + if isinstance(rng, int): + gen = torch.Generator(device=device) + gen.manual_seed(rng) + return gen + + if rng is None: + return torch.Generator(device=device) + + raise ValueError(f"Invalid rng type: {type(rng)}") + + @dataclass class SimState: """State representation for atomistic systems with batched operations support. @@ -104,6 +133,8 @@ def charge(self) -> torch.Tensor: ... # noqa: D102 @property def spin(self) -> torch.Tensor: ... # noqa: D102 + _rng: int | torch.Generator | None = field(default=None, repr=False) + _atom_attributes: ClassVar[set[str]] = { "positions", "masses", @@ -111,7 +142,23 @@ def spin(self) -> torch.Tensor: ... # noqa: D102 "system_idx", } _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} - _global_attributes: ClassVar[set[str]] = {"pbc"} + _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} + + @property + def rng(self) -> torch.Generator: + """Lazily-initialised per-batch PRNG. + + On first access, if no generator has been assigned, a new + ``torch.Generator`` is created on the same device as the state. + Assign ``None`` to reset, or assign a seeded ``torch.Generator`` + or an ``int`` seed for reproducibility. + """ + self._rng = coerce_prng(self._rng, self.device) + return self._rng + + @rng.setter + def rng(self, value: int | torch.Generator | None) -> None: + self._rng = value def __post_init__(self) -> None: # noqa: C901 """Initialize the SimState and validate the arguments.""" @@ -371,6 +418,17 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: raise ValueError("Degrees of freedom cannot be zero or negative") return dof_per_system + @staticmethod + def _clone_attr(value: object) -> object: + """Clone a single attribute value, handling torch.Generator specially.""" + if isinstance(value, torch.Tensor): + return value.clone() + if isinstance(value, torch.Generator): + new = torch.Generator(device=value.device) + new.set_state(value.get_state()) + return new + return copy.deepcopy(value) + def clone(self) -> Self: """Create a deep copy of the SimState. @@ -380,13 +438,7 @@ def clone(self) -> Self: Returns: SimState: A new SimState object with the same properties as the original """ - attrs = {} - for attr_name, attr_value in self.attributes.items(): - if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.clone() - else: - attrs[attr_name] = copy.deepcopy(attr_value) - + attrs = {name: self._clone_attr(val) for name, val in self.attributes.items()} return type(self)(**attrs) @classmethod @@ -417,10 +469,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: attrs = {} for attr_name, attr_value in state.attributes.items(): if attr_name in cls._get_all_attributes(): - if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.clone() - else: - attrs[attr_name] = copy.deepcopy(attr_value) + attrs[attr_name] = cls._clone_attr(attr_value) # Add/override with additional attributes attrs.update(additional_attrs) @@ -730,6 +779,8 @@ def _state_to_device[T: SimState]( for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) + elif isinstance(attr_value, torch.Generator): + attrs[attr_name] = coerce_prng(attr_value, device) if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) diff --git a/torch_sim/testing.py b/torch_sim/testing.py index 8ba6589e8..772fde134 100644 --- a/torch_sim/testing.py +++ b/torch_sim/testing.py @@ -160,19 +160,40 @@ def make_sio2_sim_state( return ts.io.atoms_to_state(atoms, device, dtype) -def _rattle_sim_state(sim_state: ts.SimState, seed: int = 3) -> ts.SimState: - """Apply Weibull-distributed random displacements to positions.""" +def _rattle_sim_state( + sim_state: ts.SimState, + seed: int | None = None, + scale: float = 0.1, + concentration: float = 1.0, +) -> ts.SimState: + """Apply Weibull-distributed random displacements to positions. + + Uses the state's ``rng`` (seeded with *seed*) so no global RNG state is touched. + """ sim_state = sim_state.clone() - rng_state = torch.random.get_rng_state() - try: - torch.manual_seed(seed) - weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) - rnd = torch.randn_like(sim_state.positions) - rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) - shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd - sim_state.positions = sim_state.positions + shifts - finally: - torch.random.set_rng_state(rng_state) + if seed is not None: + sim_state.rng = seed + rng = sim_state.rng + + # Sample Directions on the unit sphere to displace atoms + rnd = torch.randn( + sim_state.positions.shape, + device=sim_state.device, + dtype=sim_state.dtype, + generator=rng, + ) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) + + # Sample magnitudes from Weibull distribution so large displacements are less likely + # Weibull via inverse CDF: X = scale * (-ln(U))^(1/concentration) + u = torch.rand( + rnd.shape[0], + device=sim_state.device, + dtype=sim_state.dtype, + generator=rng, + ) + weibull_samples = scale * (-torch.log(u)) ** (1.0 / concentration) + sim_state.positions = sim_state.positions + weibull_samples.unsqueeze(-1) * rnd return sim_state diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 94c6dab2d..09b015361 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -48,3 +48,6 @@ class BravaisType(StrEnum): list["PhonopyAtoms"], "SimState", ] + +# Type alias accepted by coerce_prng +PRNGLike = int | torch.Generator | None diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 634ee93ab..bb006bf52 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -258,10 +258,7 @@ def random_packed_structure( # Extract number of atoms for each element from composition element_counts = [int(i) for i in composition.as_dict().values()] - # Set up reproducible random number generator - generator = torch.Generator(device=device) - if seed is not None: - generator.manual_seed(seed) + generator = ts.state.coerce_prng(seed, device) log = [] # Generate initial random positions in fractional coordinates @@ -298,6 +295,7 @@ def random_packed_structure( atomic_numbers=atomic_numbers, cell=cell, pbc=True, + _rng=generator, ) state = ts.fire_init(state, model) print(f"Initial energy: {state.energy.item():.4f}") @@ -385,9 +383,7 @@ def random_packed_structure_multi( print(f"Creating structure with {N_atoms} atoms: {element_dict}") # Set up random number generator with optional seed for reproducibility - generator = torch.Generator(device=device) - if seed is not None: - generator.manual_seed(seed) + generator = ts.state.coerce_prng(seed, device) # Generate initial random positions in fractional coordinates [0,1] positions = torch.rand((N_atoms, 3), device=device, dtype=dtype, generator=generator) @@ -425,6 +421,7 @@ def random_packed_structure_multi( atomic_numbers=atomic_numbers, cell=cell, pbc=True, + _rng=generator, ) # Set up FIRE optimizer with unit masses for all atoms state = ts.fire_init(state_dict, model)