Adds a GPU-optimized execution backend using Triton JIT kernels for the Wigner D-matrix transformations (gather_rotate, rotate_back).

Key components:
- UMASFastGPUBackend: extends UMASFastPytorchBackend with Triton ops
- triton/ops.py: public API (scatter_wigner, gather_wigner autograd Functions)
- triton/constants.py: Wigner transformation constants
- triton/_kernels/: Triton JIT kernels for lmax=2 transforms

Requirements: lmax == 2, mmax == 2, sphere_channels % 128 == 0, merge_mole=True

Performance: ~33% speedup over the umas_fast_pytorch backend
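To make the gather_rotate operation concrete, here is a minimal pure-Python sketch of the gather-then-rotate pattern the Triton kernels implement on GPU. The names, shapes, and signatures are illustrative stand-ins, not the backend's actual API: the real kernels operate on batched (edges, coefficients, channels) tensors with lmax=2 Wigner D-matrices.

```python
def matvec(m, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

def gather_rotate(node_feats, edge_src, wigner):
    """For each edge e: gather the source node's feature vector and
    rotate it with that edge's rotation matrix (the Wigner D-matrix
    in the real kernel)."""
    return [matvec(wigner[e], node_feats[edge_src[e]])
            for e in range(len(edge_src))]

# Toy example: 2 nodes with 2-dim features, 2 edges.
node_feats = [[1.0, 0.0], [0.0, 2.0]]
edge_src = [0, 1]                      # source node of each edge
identity = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]        # stand-in "rotation" per edge
wigner = [identity, swap]

print(gather_rotate(node_feats, edge_src, wigner))
# [[1.0, 0.0], [2.0, 0.0]]
```

The GPU version launches one program instance per edge, which is why the edge count can come from the launch grid rather than a kernel argument.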
The E parameter was passed to the kernels but never used inside the kernel bodies; edge_id is obtained from tl.program_id(0) instead. Since E was marked as tl.constexpr, Triton was recompiling the kernel for every unique edge count, causing unnecessary JIT compilation overhead.

Removed from 5 kernels:
- wigner_lmax2_fwd_kernel
- wigner_lmax2_bwd_dx_kernel
- wigner_lmax2_bwd_dw_kernel
- l_to_m_kernel
- m_to_l_kernel
- Update test imports for new _kernels/ subfolder paths
- Add MockEdgeDegreeEmbedding to validation tests for the activation_checkpoint_chunk_size check
UMASFastPytorchBackend.validate() now checks both:
1. Model-level: edge_degree_embedding.activation_checkpoint_chunk_size
2. Settings-level: settings.activation_checkpointing

Both checks are needed because validate() runs BEFORE the model is rebuilt with inference settings.
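The two-level check above can be sketched as follows. This is a hedged, simplified stand-in: the attribute names mirror the description, but the surrounding classes and the error type are hypothetical, not the real backend types.

```python
class ValidationError(Exception):
    """Hypothetical error type for rejected configurations."""

def validate(model, settings):
    """Reject configurations the fast backend cannot handle.

    Both checks are needed because validate() runs before the model is
    rebuilt with the inference settings: the current model may still
    carry a training-time chunk size, while the settings describe what
    the rebuilt model will use.
    """
    # 1. Model-level check on the not-yet-rebuilt model.
    if getattr(model.edge_degree_embedding,
               "activation_checkpoint_chunk_size", None) is not None:
        raise ValidationError(
            "activation checkpoint chunking is not supported")
    # 2. Settings-level check on the inference settings.
    if getattr(settings, "activation_checkpointing", False):
        raise ValidationError(
            "activation checkpointing is not supported")
```

Checking only one level would miss configurations where the flag lives in the other place, which is exactly what the MockEdgeDegreeEmbedding test exercises.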
Major cleanup of the triton folder, reducing the codebase by ~3000 lines while maintaining the 1.31x speedup (15.14 QPS vs 11.57 QPS baseline).

Changes:
- Remove 4 dead kernels from wigner_transform.py (wigner_lmax2_fwd_kernel, wigner_lmax2_bwd_dx_kernel, l_to_m_kernel, m_to_l_kernel)
- Remove dead fused_wigner_backward_scatter_add from gather_wigner_bwd.py
- Remove unused edge_distance parameter from forward methods
- Inline helper functions into autograd Function methods:
  - _wigner_lmax2_bwd_dw -> UMASFastGPUPermuteWignerInvEdgeToNode.backward
  - _fused_m_to_l_wigner_fwd -> UMASFastGPUPermuteWignerInvEdgeToNode.forward
  - _fused_wigner_bwd_dx_l_to_m -> UMASFastGPUPermuteWignerInvEdgeToNode.backward
- Rename autograd Functions for clarity:
  - GatherRotateFunction -> UMASFastGPUNodeToEdgeWignerPermute
  - RotateBackFunction -> UMASFastGPUPermuteWignerInvEdgeToNode
- Rename public methods:
  - gather_rotate -> node_to_edge_wigner_permute
  - rotate_back -> permute_wigner_inv_edge_to_node
- Clean up module docstrings and __all__ exports

Benchmark results (2000 atoms, compiled):
- general: 11.57 QPS (baseline)
- umas_fast_pytorch: 12.79 QPS (1.10x)
- umas_fast_gpu: 15.14 QPS (1.31x)
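The renamed permute_wigner_inv_edge_to_node (formerly rotate_back) is the inverse path of the gather: rotate each edge's features back with the inverse Wigner matrix, then scatter-add them onto destination nodes. A minimal pure-Python sketch under simplified, hypothetical signatures (the real method is a GPU autograd Function over batched tensors):

```python
def matvec(m, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def permute_wigner_inv_edge_to_node(edge_feats, edge_dst, wigner_inv,
                                    n_nodes):
    """Formerly rotate_back: rotate each edge's features back into the
    global frame and scatter-add them onto destination nodes."""
    dim = len(edge_feats[0])
    out = [[0.0] * dim for _ in range(n_nodes)]
    for e, feats in enumerate(edge_feats):
        rotated = matvec(wigner_inv[e], feats)
        dst = edge_dst[e]
        out[dst] = [a + b for a, b in zip(out[dst], rotated)]
    return out

# Toy example: two edges both pointing at node 1, identity "rotations".
identity = [[1.0, 0.0], [0.0, 1.0]]
edge_feats = [[1.0, 2.0], [1.0, 2.0]]
result = permute_wigner_inv_edge_to_node(
    edge_feats, [1, 1], [identity, identity], n_nodes=2)
print(result)  # [[0.0, 0.0], [2.0, 4.0]]
```

The scatter-add is the step that needs care on GPU (concurrent writes to the same node), which is why it lives in a dedicated autograd Function rather than plain indexing.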