Umas fast gpu backend #1826

Draft
misko wants to merge 5 commits into main from umas_fast_gpu_backend_clean
Conversation


misko commented Feb 26, 2026

No description provided.

Adds a GPU-optimized execution backend using Triton JIT kernels for
Wigner D-matrix transformations (gather_rotate, rotate_back).

Key components:
- UMASFastGPUBackend: Extends UMASFastPytorchBackend with Triton ops
- triton/ops.py: Public API (scatter_wigner, gather_wigner autograd Functions)
- triton/constants.py: Wigner transformation constants
- triton/_kernels/: Triton JIT kernels for lmax=2 transforms

Requirements: lmax==2, mmax==2, sphere_channels % 128 == 0, merge_mole=True
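
The backend's preconditions can be sketched as a simple guard function. This is an illustrative stand-in, not the backend's actual validate() logic; the function name check_requirements and its error messages are assumptions:

```python
def check_requirements(
    lmax: int, mmax: int, sphere_channels: int, merge_mole: bool
) -> None:
    """Hypothetical precondition check mirroring the requirements above."""
    if lmax != 2 or mmax != 2:
        raise ValueError("umas_fast_gpu requires lmax == 2 and mmax == 2")
    if sphere_channels % 128 != 0:
        raise ValueError("sphere_channels must be a multiple of 128")
    if not merge_mole:
        raise ValueError("merge_mole must be True")
```

A configuration with lmax=2, mmax=2, sphere_channels=128, and merge_mole=True passes; anything else raises before the Triton kernels are ever launched.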

Performance: ~33% speedup over the umas_fast_pytorch backend

The E parameter was passed to the kernels but never used inside the kernel
bodies; edge_id is obtained from tl.program_id(0) instead. Since E was marked
as tl.constexpr, Triton was recompiling each kernel for every unique edge
count, causing unnecessary JIT compilation overhead.
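
The recompilation cost can be illustrated with a plain-Python analogy: a toy dict stands in for Triton's JIT cache, which is keyed on (among other things) the values of tl.constexpr arguments. No Triton API is used here; compile_cache and launch are illustrative names:

```python
# Toy model of a JIT cache keyed by constexpr argument values: every
# distinct value of a constexpr parameter yields a separate compiled variant.
compile_cache: dict[tuple, str] = {}

def launch(kernel_name: str, constexpr_args: tuple) -> str:
    key = (kernel_name, constexpr_args)
    if key not in compile_cache:
        # In real Triton this is a JIT compilation, not a string format.
        compile_cache[key] = f"compiled {kernel_name} for {constexpr_args}"
    return compile_cache[key]

# With E as a constexpr, each unique edge count triggers a "recompile":
for E in (1000, 1001, 1002):
    launch("wigner_lmax2_fwd_kernel", (E,))
print(len(compile_cache))  # 3 compiled variants

# With E removed from the constexpr signature, one compilation serves all:
compile_cache.clear()
for E in (1000, 1001, 1002):
    launch("wigner_lmax2_fwd_kernel", ())
print(len(compile_cache))  # 1 compiled variant
```

Dropping E from the signature is safe precisely because the kernel bodies derive edge_id from tl.program_id(0) and never read E.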

Removed from 5 kernels:
- wigner_lmax2_fwd_kernel
- wigner_lmax2_bwd_dx_kernel
- wigner_lmax2_bwd_dw_kernel
- l_to_m_kernel
- m_to_l_kernel

- Update test imports for the new _kernels/ subfolder paths
- Add MockEdgeDegreeEmbedding to validation tests for the
  activation_checkpoint_chunk_size check

The UMASFastPytorchBackend.validate() method now checks both:
1. Model-level: edge_degree_embedding.activation_checkpoint_chunk_size
2. Settings-level: settings.activation_checkpointing

Both checks are needed because validate() runs BEFORE the model is
rebuilt with inference settings.
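
A hedged sketch of that two-level check, using only the attribute names quoted above; the ValidationError class, the validate signature shown here, and the specific pass/fail conditions are illustrative assumptions:

```python
class ValidationError(ValueError):
    """Illustrative error type for failed backend validation."""

def validate(model, settings) -> None:
    """Two-level check: validate() runs BEFORE the model is rebuilt with
    inference settings, so both the already-built model and the target
    settings must be inspected."""
    # 1. Model-level: the currently built module's checkpoint chunk size.
    chunk = getattr(
        model.edge_degree_embedding, "activation_checkpoint_chunk_size", None
    )
    if chunk is not None:
        raise ValidationError(
            "edge_degree_embedding.activation_checkpoint_chunk_size must be None"
        )
    # 2. Settings-level: the inference settings the model will be rebuilt with.
    if getattr(settings, "activation_checkpointing", False):
        raise ValidationError("settings.activation_checkpointing must be disabled")
```

Checking only the settings would miss a model that was already built with activation checkpointing; checking only the model would miss settings that will re-enable it on rebuild.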

Major cleanup of the triton folder, reducing the codebase by ~3000 lines
while maintaining a 1.31x speedup (15.14 QPS vs. the 11.57 QPS baseline).

Changes:
- Remove 4 dead kernels from wigner_transform.py (wigner_lmax2_fwd_kernel,
  wigner_lmax2_bwd_dx_kernel, l_to_m_kernel, m_to_l_kernel)
- Remove dead fused_wigner_backward_scatter_add from gather_wigner_bwd.py
- Remove unused edge_distance parameter from forward methods
- Inline helper functions into autograd backward methods:
  - _wigner_lmax2_bwd_dw -> UMASFastGPUPermuteWignerInvEdgeToNode.backward
  - _fused_m_to_l_wigner_fwd -> UMASFastGPUPermuteWignerInvEdgeToNode.forward
  - _fused_wigner_bwd_dx_l_to_m -> UMASFastGPUPermuteWignerInvEdgeToNode.backward
- Rename autograd Functions for clarity:
  - GatherRotateFunction -> UMASFastGPUNodeToEdgeWignerPermute
  - RotateBackFunction -> UMASFastGPUPermuteWignerInvEdgeToNode
- Rename public methods:
  - gather_rotate -> node_to_edge_wigner_permute
  - rotate_back -> permute_wigner_inv_edge_to_node
- Clean up module docstrings and __all__ exports
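
The renamed operations describe a gather-then-rotate / rotate-then-scatter data flow. A dependency-free sketch of that flow, with plain Python lists standing in for tensors and tiny 2x2 matrices standing in for the real Wigner D-matrices (everything here beyond the two public method names is illustrative):

```python
def node_to_edge_wigner_permute(node_feats, src_index, wigner):
    """Gather each edge's source-node features, then rotate them by that
    edge's Wigner matrix (toy 2x2 matrices instead of real D-matrices)."""
    out = []
    for e, src in enumerate(src_index):
        x = node_feats[src]
        D = wigner[e]
        out.append([sum(D[i][j] * x[j] for j in range(len(x)))
                    for i in range(len(D))])
    return out

def permute_wigner_inv_edge_to_node(edge_feats, dst_index, wigner_inv, num_nodes):
    """Rotate each edge's features back with the inverse Wigner matrix,
    then scatter-add the results onto the destination nodes."""
    out = [[0.0, 0.0] for _ in range(num_nodes)]
    for e, dst in enumerate(dst_index):
        y = edge_feats[e]
        Dinv = wigner_inv[e]
        r = [sum(Dinv[i][j] * y[j] for j in range(len(y)))
             for i in range(len(Dinv))]
        for i, v in enumerate(r):
            out[dst][i] += v
    return out
```

In the actual backend these are autograd Functions backed by Triton kernels, so the gather/rotate and rotate/scatter steps are fused on the GPU rather than looped over in Python; the new names make the direction of the data movement explicit.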

Benchmark results (2000 atoms, compiled):
  general:           11.57 QPS (baseline)
  umas_fast_pytorch: 12.79 QPS (1.10x)
  umas_fast_gpu:     15.14 QPS (1.31x)