
Introduce PRNG to SimState and add reproducibility docs.#460

Open
CompRhys wants to merge 10 commits into `main` from `prng-simstate-reproduce`

Conversation

CompRhys (Member) commented Feb 21, 2026

The only messy bit is resumption: for serialization, the only way to do it seems to be torch.save, and asking the user to store the pickle manually feels awkward.


AI Overview

Every SimState now carries an optional _rng field (a torch.Generator) that controls all stochastic operations: momentum initialization, Langevin OU noise, V-Rescale Gamma draws, and C-Rescale barostat noise. No integrator init or step function accepts a seed or prng argument anymore — seeding is done exclusively through the state.

The rng property

```python
state.rng = 42          # int → coerced to a seeded Generator
state.rng = gen         # torch.Generator used directly
state.rng = None        # reset; next access creates an unseeded Generator
samples = state.rng     # lazily initialises if _rng is None/int, then returns it
```
  • Lazy: if _rng is None (the default), accessing state.rng creates a new torch.Generator on the state's device and stores it back. No Generator is allocated until first use.
  • Coercing: if _rng is an int, accessing state.rng converts it to a seeded Generator via coerce_prng() and stores it back, so subsequent accesses return the same (advancing) Generator.
  • Advancing: because a single torch.Generator object is stored, its internal state advances with each draw, giving a proper random stream rather than re-seeding every step.
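Assuming the behaviour described above, the property can be sketched as follows. The `State` class and `coerce_prng` helper here are simplified stand-ins for the PR's actual `SimState` implementation, not the real code:

```python
import torch

def coerce_prng(rng, device="cpu"):
    """Coerce an int seed or None into a torch.Generator.
    Hypothetical stand-in for the PR's coerce_prng() helper."""
    if isinstance(rng, int):
        gen = torch.Generator(device=device)
        gen.manual_seed(rng)
        return gen
    if rng is None:
        return torch.Generator(device=device)  # unseeded
    return rng

class State:
    """Toy stand-in for SimState, illustrating the lazy rng property."""

    def __init__(self):
        self._rng = None

    @property
    def rng(self):
        # Lazily create/coerce, then store back so every subsequent
        # access returns the same (advancing) Generator object.
        if not isinstance(self._rng, torch.Generator):
            self._rng = coerce_prng(self._rng)
        return self._rng

    @rng.setter
    def rng(self, value):
        self._rng = value

state = State()
state.rng = 42
a = torch.randn(3, generator=state.rng)  # each draw advances the stream
b = torch.randn(3, generator=state.rng)  # so this differs from `a`

state2 = State()
state2.rng = 42                          # same seed → same stream
assert torch.equal(a, torch.randn(3, generator=state2.rng))
```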

Cloning

state.clone() deep-copies the Generator via get_state() / set_state(), producing an independent copy with identical initial RNG state. Drawing from one does not affect the other.
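The get_state()/set_state() copy semantics can be demonstrated with plain torch.Generator objects (no SimState needed):

```python
import torch

g = torch.Generator()
g.manual_seed(7)

# Deep-copy the generator the way state.clone() does:
# same internal RNG state, new object.
g_copy = torch.Generator()
g_copy.set_state(g.get_state())

x = torch.randn(4, generator=g)       # advances only g
y = torch.randn(4, generator=g_copy)  # identical first draw from the copy
assert torch.equal(x, y)

# Independent from here on: drawing from one does not affect the other,
# but matching draw counts still yield matching streams.
assert torch.equal(
    torch.randn(4, generator=g), torch.randn(4, generator=g_copy)
)
```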

Splitting

state.split() copies global attributes (including _rng) to every piece. All resulting single-system states share the same Generator value (copied), not the same object.

Concatenating

concatenate_states([s1, s2, ...]) takes global attributes from the first state. The resulting batch uses s1's Generator; other states' Generators are discarded.

Device movement

state.to(device) moves the Generator to the target device via coerce_prng(), which creates a new Generator on the target device and copies the RNG state if devices differ.

Serialisation

```python
torch.save(state.rng.get_state(), "rng.pt")  # save
gen = torch.Generator(device=state.device)
gen.set_state(torch.load("rng.pt"))
state.rng = gen                              # restore
```

What changed

  • _rng moved from MDState to SimState (it's a global attribute, not MD-specific).
  • All seed= / prng= parameters removed from integrator init functions.
  • initialize_momenta takes generator: torch.Generator | None directly.
  • V-Rescale Gamma sampling switched from torch.distributions.Gamma (unseeded) to torch._standard_gamma(..., generator=rng) so it's now fully seedable.
  • _rattle_sim_state in testing.py refactored to use state.rng instead of saving/restoring global RNG state.
  • coerce_prng handles cross-device Generator transfer.
  • _state_to_device handles Generator device movement.

@CompRhys CompRhys added api API design discussions breaking Breaking changes keep-open PRs to be ignored by StaleBot labels Feb 21, 2026
@CompRhys CompRhys requested a review from thomasloux February 21, 2026 20:05
```python
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)

# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
```
after pytorch/pytorch#165865, randn_like supports a generator argument in torch 2.10, but I am not sure we want to pin to 2.10 given that not all the models people want to use will support it.
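Until the minimum torch version includes that change, an equivalent seeded draw can be written with `torch.randn` and an explicit shape; this is a sketch of the portable workaround, with `momenta` standing in for `state.momenta`:

```python
import torch

momenta = torch.zeros(5, 3, dtype=torch.float64)  # stand-in for state.momenta
rng = torch.Generator()
rng.manual_seed(0)

# randn_like(momenta, generator=rng) needs torch >= 2.10;
# torch.randn with an explicit shape is the seedable equivalent on older torch.
noise = torch.randn(
    momenta.shape, generator=rng, device=momenta.device, dtype=momenta.dtype
)
assert noise.shape == momenta.shape
```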

@CompRhys CompRhys marked this pull request as ready for review February 21, 2026 20:35
```python
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd
```
avoid weibull.sample() as it cannot be seeded.
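One seedable alternative (not necessarily what the PR uses) is inverse-transform sampling: for Weibull(scale λ, concentration k), X = λ · (−ln(1 − U))^(1/k) with U uniform on [0, 1), and `torch.rand` does accept a generator. A minimal sketch, with `seeded_weibull` a hypothetical helper name:

```python
import torch

def seeded_weibull(shape, scale, concentration, generator=None):
    """Seedable Weibull draw via inverse-transform sampling
    (torch.distributions.Weibull.sample() takes no generator)."""
    u = torch.rand(shape, generator=generator)
    return scale * (-torch.log1p(-u)) ** (1.0 / concentration)

g = torch.Generator()
g.manual_seed(1)
shifts = seeded_weibull((10, 3), scale=0.1, concentration=1.0, generator=g)
assert (shifts >= 0).all()  # Weibull support is non-negative
```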

```python
# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
```
avoid Gamma.sample() as it cannot be seeded.
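The seedable replacement the PR describes uses `torch._standard_gamma` (a private API), which samples Gamma(alpha, 1); since the target distribution has rate 1/2 (scale 2), the draw is rescaled by 2. A sketch under those assumptions:

```python
import torch

dof = torch.tensor([30.0, 60.0])  # example degrees-of-freedom per system
rng = torch.Generator()
rng.manual_seed(3)

# Gamma((dof - 1)/2, rate=1/2): _standard_gamma draws Gamma(alpha, 1),
# so multiplying by 2 (= 1/rate) rescales to the target distribution.
alpha = (dof - 1) / 2
r2 = 2.0 * torch._standard_gamma(alpha, generator=rng)
assert (r2 > 0).all()  # gamma samples are strictly positive
```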


```python
@staticmethod
def _clone_attr(value: object) -> object:
    """Clone a single attribute value, handling torch.Generator specially."""
```
forks have identical rng states.



```diff
-def calculate_momenta(
+def initialize_momenta(
```
driveby: this was a misleading name.




Development

Successfully merging this pull request may close these issues.

  • Add a page in docs about reproducibility
  • Add a seed for integrator step function to reproduce results
  • Allow seeds to be set for individual batches
