Introduce PRNG to SimState and add reproducibility docs. #460
Open
Conversation
This was linked to issues on Feb 21, 2026.
CompRhys commented on Feb 21, 2026
```python
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)

# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
```
After pytorch/pytorch#165865, `randn_like` is seedable in torch 2.10, but I am not sure we want to pin to 2.10 given that not all of the models people want to use will support it.
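A seedable workaround on older torch versions (a sketch; `momenta` below is a hypothetical stand-in for `state.momenta`) is to spell out the shape explicitly and pass a `torch.Generator` to `torch.randn`, which has long accepted one:

```python
import torch

# Hypothetical stand-in for state.momenta.
momenta = torch.zeros(4, 3, dtype=torch.float64)
rng = torch.Generator().manual_seed(7)

# torch.randn accepts a generator on all recent torch versions,
# unlike randn_like before pytorch/pytorch#165865.
noise = torch.randn(
    momenta.shape, generator=rng, device=momenta.device, dtype=momenta.dtype
)
```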
CompRhys commented on Feb 22, 2026
```python
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd
```
Avoid `weibull.sample()` as it cannot be seeded.
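One seedable alternative (a sketch, not necessarily the PR's actual fix) is inverse-transform sampling: if U ~ Uniform(0, 1), then `scale * (-log(1 - U)) ** (1 / concentration)` follows a Weibull distribution, and `torch.rand` accepts a generator:

```python
import torch

def sample_weibull(scale, concentration, shape, generator=None):
    # Inverse-transform sampling: for U ~ Uniform(0, 1),
    # scale * (-log(1 - U)) ** (1 / concentration) ~ Weibull(scale, concentration).
    u = torch.rand(shape, generator=generator)
    return scale * (-torch.log1p(-u)) ** (1.0 / concentration)

gen = torch.Generator().manual_seed(0)
magnitudes = sample_weibull(0.1, 1.0, (5, 3), generator=gen)
```

With `concentration=1` this reduces to an exponential distribution with mean `scale`, matching the distribution in the snippet above.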
CompRhys commented on Feb 22, 2026
```python
# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
```
Avoid `Gamma.sample()` as it cannot be seeded.
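The overview below notes that the PR switches to `torch._standard_gamma(..., generator=rng)`. A minimal sketch of a seedable Gamma(concentration, rate) draw along those lines (note `_standard_gamma` is a private torch function, so its signature may change between releases):

```python
import torch

def sample_gamma(concentration, rate, generator=None):
    # Gamma(concentration, rate) = standard-gamma draw scaled by 1 / rate.
    # Unlike torch.distributions.Gamma(...).sample(), this accepts a generator.
    return torch._standard_gamma(concentration, generator=generator) / rate

dof = torch.tensor([10.0, 20.0])
gen = torch.Generator().manual_seed(0)
r2 = sample_gamma((dof - 1) / 2, torch.ones_like(dof) / 2, generator=gen)
```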
CompRhys commented on Feb 22, 2026
```python
@staticmethod
def _clone_attr(value: object) -> object:
    """Clone a single attribute value, handling torch.Generator specially."""
```
Forks have identical RNG states.
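A sketch of the Generator special-casing under discussion, assuming the clone copies the RNG state via `get_state()`/`set_state()` so that the fork and the original start from identical states but advance independently:

```python
import torch

def clone_generator(gen: torch.Generator) -> torch.Generator:
    # Fork a Generator: the copy starts from the same internal state,
    # but advancing one stream does not affect the other.
    new_gen = torch.Generator(device=gen.device)
    new_gen.set_state(gen.get_state())
    return new_gen

g1 = torch.Generator().manual_seed(42)
g2 = clone_generator(g1)
a = torch.randn(3, generator=g1)
b = torch.randn(3, generator=g2)
```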
CompRhys commented on Feb 22, 2026
```diff
-def calculate_momenta(
+def initialize_momenta(
```
Drive-by: this was a misleading name.
The only messy bit is resumption: for serialization, it seems the only way to do it is with `torch.save`, and asking the user to store the pickle manually feels awkward.
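One way to make resumption less awkward (a sketch, not what the PR implements) is to checkpoint only `Generator.get_state()` — a byte tensor — inside whatever `torch.save` payload the user already writes, then restore it with `set_state()`:

```python
import os
import tempfile

import torch

gen = torch.Generator().manual_seed(0)
torch.randn(3, generator=gen)  # advance the stream mid-simulation

# Checkpoint: store the RNG state alongside other simulation data.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"rng_state": gen.get_state()}, path)

# Resume: restore the stream exactly where it left off.
restored = torch.Generator()
restored.set_state(torch.load(path)["rng_state"])
```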
AI Overview
Every `SimState` now carries an optional `_rng` field (a `torch.Generator`) that controls all stochastic operations: momentum initialization, Langevin OU noise, V-Rescale Gamma draws, and C-Rescale barostat noise. No integrator init or step function accepts a `seed` or `prng` argument anymore; seeding is done exclusively through the state.

The `rng` property

- If `_rng` is `None` (the default), accessing `state.rng` creates a new `torch.Generator` on the state's device and stores it back. No Generator is allocated until first use.
- If `_rng` is an `int`, accessing `state.rng` converts it to a seeded Generator via `coerce_prng()` and stores it back, so subsequent accesses return the same (advancing) Generator.
- Once a `torch.Generator` object is stored, its internal state advances with each draw, giving a proper random stream rather than re-seeding every step.
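A minimal sketch of this lazy-coercion behavior, using hypothetical simplified versions of `coerce_prng` and `SimState` (CPU-only; the real classes carry many more fields):

```python
import torch

def coerce_prng(prng, device="cpu"):
    # Accept None, an int seed, or an existing Generator.
    if isinstance(prng, torch.Generator):
        return prng
    gen = torch.Generator(device=device)
    if isinstance(prng, int):
        gen.manual_seed(prng)
    return gen

class SimState:
    def __init__(self, _rng=None, device="cpu"):
        self._rng = _rng
        self.device = device

    @property
    def rng(self):
        # Lazily create/coerce the Generator on first access and store it
        # back, so later accesses reuse the same advancing stream.
        if not isinstance(self._rng, torch.Generator):
            self._rng = coerce_prng(self._rng, device=self.device)
        return self._rng

state = SimState(_rng=123)
first = state.rng
```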
Cloning

`state.clone()` deep-copies the Generator via `get_state()`/`set_state()`, producing an independent copy with identical initial RNG state. Drawing from one does not affect the other.
Splitting

`state.split()` copies global attributes (including `_rng`) to every piece. All resulting single-system states share the same Generator value (copied), not the same object.
Concatenating

`concatenate_states([s1, s2, ...])` takes global attributes from the first state. The resulting batch uses `s1`'s Generator; other states' Generators are discarded.
Device movement

`state.to(device)` moves the Generator to the target device via `coerce_prng()`, which creates a new Generator on the target device and copies the RNG state if devices differ.

Serialisation
What changed

- `_rng` moved from `MDState` to `SimState` (it's a global attribute, not MD-specific).
- `seed=`/`prng=` parameters removed from integrator init functions.
- `initialize_momenta` takes `generator: torch.Generator | None` directly.
- Switched from `torch.distributions.Gamma` (unseeded) to `torch._standard_gamma(..., generator=rng)`, so Gamma draws are now fully seedable.
- `_rattle_sim_state` in `testing.py` refactored to use `state.rng` instead of saving/restoring global RNG state.
- `coerce_prng` handles cross-device Generator transfer.
- `_state_to_device` handles Generator device movement.