
Standardize parameter typing to float | torch.Tensor when appropriate #466

Merged
CompRhys merged 7 commits into main from float-or-tensor
Feb 25, 2026
Conversation

@CompRhys (Member)

@CompRhys CompRhys commented Feb 24, 2026

see #325

I am unsure if step functions should allow floats because then we have overhead at the integration level. Init functions are a one off cost not a linear cost.

@CompRhys CompRhys requested a review from orionarcher February 24, 2026 13:49
@CompRhys (Member Author)

The changes here make the code less clean.

@orionarcher (Collaborator) left a comment

LGTM

@CompRhys (Member Author)

Will wait on a speed check from @falletta and then merge if it looks fine. I can't imagine it's significant but no real harm in checking.

Signed-off-by: Rhys Goodall <rhys.goodall@outlook.com>
@CompRhys CompRhys marked this pull request as ready for review February 24, 2026 18:58
@CompRhys CompRhys linked an issue Feb 24, 2026 that may be closed by this pull request
@janosh (Collaborator)

janosh commented Feb 25, 2026

@CompRhys in case this gets merged today, could you apply

diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py
--- a/torch_sim/models/fairchem.py
+++ b/torch_sim/models/fairchem.py
@@ -201,11 +201,11 @@ class FairChemModel(ModelInterface):
             zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False)
         ):
             # Extract system data
-            positions = sim_state.positions[c - n : c].cpu().numpy()
-            atomic_nums = sim_state.atomic_numbers[c - n : c].cpu().numpy()
-            pbc = sim_state.pbc.cpu().numpy()
+            positions = sim_state.positions[c - n : c].detach().cpu().numpy()
+            atomic_nums = sim_state.atomic_numbers[c - n : c].detach().cpu().numpy()
+            pbc = sim_state.pbc.detach().cpu().numpy()
             cell = (
-                sim_state.row_vector_cell[idx].cpu().numpy()
+                sim_state.row_vector_cell[idx].detach().cpu().numpy()
                 if sim_state.row_vector_cell is not None
                 else None
             )

to close #469. Could add a single-step fairchem relax test, but if models get moved to external repos it may not be worth it.
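The failure mode the diff above addresses can be reproduced in isolation: calling `.numpy()` on a tensor that requires grad raises a `RuntimeError`, and inserting `.detach()` first breaks the autograd link so the conversion succeeds. A minimal standalone repro (not from the torch_sim codebase):

```python
import torch

# Positions coming out of a relaxation step typically carry grad history.
positions = torch.zeros(4, 3, requires_grad=True)

# Direct conversion fails: .cpu() preserves requires_grad, and .numpy()
# refuses to convert a tensor that is part of the autograd graph.
try:
    positions.cpu().numpy()
except RuntimeError:
    print("numpy() on a grad tensor raises RuntimeError")

# .detach() returns a view outside the autograd graph, so this works.
arr = positions.detach().cpu().numpy()
print(arr.shape)  # (4, 3)
```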

@CompRhys (Member Author)

CompRhys commented Feb 25, 2026

to close #469. Could add a single-step fairchem relax test, but if models get moved to external repos it may not be worth it.

The fairchem PR I made to upstream has been sitting for a while without eyes from Meta, so I will add it here.

@CompRhys CompRhys linked an issue Feb 25, 2026 that may be closed by this pull request
@falletta (Contributor)

I see similar performance when running the 8_benchmarking.py script. Below are the results, rounded to one decimal place, obtained both before and after the changes:

=== Static benchmark, n=5000 static_time=7.9 s
=== Relax benchmark, n=500 relax_10_time=34.2 s
=== NVE benchmark, n=500 nve_time=4.5 s
=== NVT benchmark, n=500 nvt_time=4.7 s

I also checked the scalings of various operations with system size using my own profiling scripts, and everything looks good. So green light from my side @CompRhys

@CompRhys CompRhys enabled auto-merge (squash) February 25, 2026 19:37
@CompRhys CompRhys merged commit ad8624a into main Feb 25, 2026
68 checks passed
@CompRhys CompRhys deleted the float-or-tensor branch February 25, 2026 19:47


Development

Successfully merging this pull request may close these issues.

fairchem model relaxations crash from .cpu().numpy() on grad tensors
Standardize parameter typing to float | torch.Tensor when appropriate

4 participants