Conversation

@isayev (Contributor) commented Dec 16, 2025

Summary

This PR adds torch.compile support with CUDA graphs for significant speedups on GPU molecular dynamics simulations.

Based on community PR #20 from Acellera, but reworked with improvements:

  • Generalized model loading (not hardcoded to one model)
  • Backward-compatible (original cosine_cutoff signature unchanged)
  • Comprehensive test coverage
  • Code style compliance (passes pre-commit)
  • Based on current main branch

Changes

  • Add compile_mode=True parameter to AIMNet2Calculator and AIMNet2ASE
  • Add compile_nb_mode parameter throughout to avoid data-dependent control flow that breaks CUDA graph capture (see the sketch after this list)
  • Add get_model_definition_path() for mapping model names to YAML definitions
  • Add cosine_cutoff_tensor() for CUDA graphs compatibility
  • Add enable_compile_mode() to AIMNet2Base to propagate compile settings
  • Add calc_masks_fixed_nb_mode() for compile-time mask calculation
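
To make the compile_nb_mode point concrete: branching on a plain Python int is resolved while torch.compile traces the function, whereas branching on tensor values forces a graph break. A minimal, hypothetical sketch (not the PR's actual code):

```python
import torch

def forward(x: torch.Tensor, nb_mode: int) -> torch.Tensor:
    # nb_mode is a plain Python int, so torch.compile guards on its value and
    # specializes this branch away; nothing data-dependent survives tracing.
    if nb_mode == 0:
        return x * 2.0
    return x * 3.0

compiled = torch.compile(forward, fullgraph=True)
x = torch.randn(8, device="cuda")
y = compiled(x, 0)  # one static graph per nb_mode value; CUDA-graph friendly
```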

Performance

Based on benchmarks from community PR #20:

  • ~5x speedup on small molecule MD (76s → 15s for 10k steps on caffeine)
  • First call has compilation overhead (~10-30s)
  • Subsequent calls use cached CUDA graphs

Limitations

  • Only nb_mode=0 (single molecule, dense) supported
  • Requires CUDA
  • No PBC support in compile mode

Usage

```python
from aimnet.calculators import AIMNet2Calculator
calc = AIMNet2Calculator("aimnet2", compile_mode=True)
```

Or with ASE:

```python
from aimnet.calculators import AIMNet2ASE
calc = AIMNet2ASE("aimnet2", compile_mode=True)
atoms.calc = calc
```
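
For a fuller picture, a minimal MD loop might look like the following (a sketch only: the benchmark above used caffeine, a small ase.build.molecule stands in here, and a CUDA device is assumed):

```python
from ase import units
from ase.build import molecule
from ase.md.verlet import VelocityVerlet

from aimnet.calculators import AIMNet2ASE

atoms = molecule("CH3OCH3")  # stand-in structure; the benchmark used caffeine
atoms.calc = AIMNet2ASE("aimnet2", compile_mode=True)

dyn = VelocityVerlet(atoms, timestep=0.5 * units.fs)
dyn.run(10_000)  # first steps pay the one-time compile cost; later steps replay CUDA graphs
```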

Test plan

  • Existing tests pass
  • Pre-commit hooks pass
  • GPU tests with pytest -m gpu tests/test_compile.py (requires CUDA)
  • Manual benchmark with python examples/ase_md_compiled.py --compile

Closes #20 (supersedes it with improvements)

@claude (bot) commented Dec 16, 2025

Pull Request Review: torch.compile with CUDA Graphs Support

Summary

This PR adds torch.compile support with CUDA graphs for ~5x speedup on GPU molecular dynamics simulations. The implementation is well-structured and thoughtfully designed with good backward compatibility.

✅ Strengths

Code Quality

  • Excellent backward compatibility: The original API remains unchanged with compile_mode as an optional parameter
  • Clean separation of concerns: Compile mode logic is isolated with the _compile_mode and _compile_nb_mode attributes
  • Comprehensive docstrings: New functions are well-documented with clear explanations
  • Good code organization: The compile-specific logic is appropriately separated from the dynamic path

Test Coverage

  • 212 lines of new tests covering multiple scenarios:
    • Basic functionality tests
    • Energy and forces consistency between compile/normal modes
    • Multiple model support
    • CUDA graph reuse validation
    • PBC rejection (as expected)
    • ASE calculator integration
  • Appropriate pytest markers: @pytest.mark.gpu and @pytest.mark.ase are properly used
  • Good test structure: Tests check both correctness (consistency) and basic functionality

Performance & Documentation

  • Clear performance documentation: ~5x speedup is well-documented with benchmark script
  • Good example code: ase_md_compiled.py provides a complete working example
  • Clear limitations: PBC and nb_mode restrictions are well-documented

🔍 Issues & Concerns

1. Critical: Potential Device Mismatch in calc_masks_fixed_nb_mode()

Location: aimnet/nbops.py:42-44

data["_input_padded"] = torch.tensor(False)
data["_natom"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)
data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)

Issue: _input_padded is created without specifying a device, which may cause it to be on CPU while other tensors are on GPU. This could break CUDA graph capture.

Fix: Add device=data["numbers"].device to all tensor creations:

data["_input_padded"] = torch.tensor(False, device=data["numbers"].device)

2. Potential Issue: cosine_cutoff_tensor() Behavior

Location: aimnet/ops.py:54-68

```python
def cosine_cutoff_tensor(d_ij: Tensor, rc: Tensor) -> Tensor:
    return torch.where(d_ij < rc, 0.5 * (torch.cos(d_ij * (torch.pi / rc)) + 1.0), torch.zeros_like(d_ij))
```

Issue: Unlike the original cosine_cutoff() which clamps d_ij values, this doesn't clamp. The comparison d_ij < rc handles the cutoff, but the cosine computation for very small values near 0 might differ slightly from the clamped version.

Recommendation: Consider whether the behavioral difference (no clamp(min=1e-6)) is intentional or if numerical stability near zero should be preserved.
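
If the clamp should be preserved, one possible variant (a sketch; it assumes the clamp(min=1e-6) behavior described above):

```python
import torch
from torch import Tensor

def cosine_cutoff_tensor_clamped(d_ij: Tensor, rc: Tensor) -> Tensor:
    # Clamp distances away from zero before the cosine, mirroring the
    # clamp(min=1e-6) assumed from the original cosine_cutoff.
    d = d_ij.clamp(min=1e-6)
    fc = 0.5 * (torch.cos(d * (torch.pi / rc)) + 1.0)
    return torch.where(d_ij < rc, fc, torch.zeros_like(d_ij))
```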

3. Code Clarity: Fragile Model-to-YAML String Matching

Location: aimnet/calculators/model_registry.py:70-75

if "nse" in model_name.lower():
    yaml_file = "aimnet2.yaml"
elif "d3" in model_name.lower() or "pd" in model_name.lower():
    yaml_file = "aimnet2_dftd3_wb97m.yaml"
else:
    yaml_file = "aimnet2_dftd3_wb97m.yaml"

Issue: The model-to-YAML mapping logic uses string matching which could be fragile if new model names are added that don't follow this pattern.

Recommendation: Consider using a more explicit mapping dictionary or documenting this convention in the model registry YAML.
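
One possible shape for such a mapping (illustrative only; the keys and files below follow the convention quoted above, not the registry's actual contents):

```python
_MODEL_TO_YAML: dict[str, str] = {
    "aimnet2-nse": "aimnet2.yaml",
    "aimnet2-d3": "aimnet2_dftd3_wb97m.yaml",
}
_DEFAULT_YAML = "aimnet2_dftd3_wb97m.yaml"

def yaml_for_model(model_name: str) -> str:
    # Explicit lookup with a documented fallback, instead of substring matching.
    return _MODEL_TO_YAML.get(model_name.lower(), _DEFAULT_YAML)
```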

4. Potential Bug: Incorrect Tensor Shape in calc_masks_fixed_nb_mode()

Location: aimnet/nbops.py:44

data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)

Issue: mol_sizes should likely be a 0-D or 1-D tensor, but the context suggests it might need to match the batch dimension. The dynamic version at line 72 shows:

data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)

Both create scalar tensors, but in the dynamic path with padding (line 67), it's computed as (~data["mask_i"]).sum(-1) which would be 1-D for batched input.

Recommendation: Verify that mol_sizes has the correct shape for batched operations in compile mode. Should it be torch.tensor([data["numbers"].shape[1]], ...) (1-D) instead?
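
For reference, the shape difference at stake:

```python
import torch

n = 24  # hypothetical atom count
scalar = torch.tensor(n)     # 0-D: shape torch.Size([])
batched = torch.tensor([n])  # 1-D: shape torch.Size([1]), matching a batch of one
```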

5. Code Style: Trailing Newline Removals

The PR removes trailing newlines from several files (.devcontainer/devcontainer.json, .github/dependabot.yml, etc.). While this follows some style guides, it's unrelated to the feature and should ideally be in a separate commit or omitted.

6. Type Safety: Missing Type Hints

Location: Throughout new code

Several new parameters lack type hints:

  • compile_nb_mode parameters default to -1 but aren't typed
  • Return type hints are missing in some places

Recommendation: Add type hints for consistency with the existing codebase.
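
A sketch of the suggested annotation (the parameter name comes from this PR; the surrounding stub is illustrative):

```python
class AIMNet2Base:  # illustrative stub, not the real class body
    def enable_compile_mode(self, compile_nb_mode: int = -1) -> None:
        ...
```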

🔒 Security Concerns

No significant security issues identified. The implementation:

  • Validates compile_mode requirements appropriately (CUDA availability)
  • Raises clear errors for unsupported configurations
  • Doesn't introduce any obvious injection vectors or unsafe operations

⚡ Performance Considerations

Positive:

  • CUDA graphs properly utilized via triton.cudagraphs option
  • Appropriate use of fullgraph=True for maximum optimization
  • Fixed control flow avoids graph breaks
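
For reference, the kind of call these points describe (a sketch; the PR's actual invocation may differ):

```python
import torch

model = torch.nn.Linear(4, 4).cuda()  # stand-in for the loaded AIMNet2 module
compiled_model = torch.compile(
    model,
    fullgraph=True,                       # fail fast on any graph break
    options={"triton.cudagraphs": True},  # capture and replay with CUDA graphs
)
out = compiled_model(torch.randn(2, 4, device="cuda"))
```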

Potential Concerns:

  • First-call compilation overhead: Well-documented (~10-30s)
  • Memory usage: CUDA graphs can increase GPU memory consumption - not documented
  • Graph cache invalidation: No explicit handling for different input shapes, but the fixed nb_mode=0 restriction mitigates this

Recommendation: Document expected GPU memory overhead with CUDA graphs.
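
One way a user could quantify that overhead (a sketch; compare the number with and without compile_mode=True):

```python
import torch
from aimnet.calculators import AIMNet2Calculator

torch.cuda.reset_peak_memory_stats()
calc = AIMNet2Calculator("aimnet2", compile_mode=True)
# ... run a few evaluations here to trigger compilation and graph capture ...
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"peak GPU memory: {peak_mib:.0f} MiB")
```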

📋 Test Coverage Assessment

What's Covered:

✅ Basic functionality
✅ Energy/forces consistency
✅ Multiple model support
✅ CUDA graph reuse
✅ PBC rejection
✅ ASE integration

What's Missing:

❌ Edge cases (empty molecules, single atoms)
❌ Error handling paths
❌ Memory usage tests
❌ Stress tests with varying molecule sizes
❌ Tests for the get_model_definition_path() function
❌ Tests for different YAML file selections
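
As a starting point for the get_model_definition_path() gap, a unit test might look like this (a sketch; it assumes the function returns a filesystem path as str or Path):

```python
import os

from aimnet.calculators.model_registry import get_model_definition_path

def test_model_definition_path_points_to_yaml():
    path = str(get_model_definition_path("aimnet2"))
    assert path.endswith(".yaml")
    assert os.path.isfile(path)
```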

🎯 Recommendations

Must Fix:

  1. Fix device mismatch in calc_masks_fixed_nb_mode() (Issue #1)
  2. Verify mol_sizes tensor shape (Issue #4)

Should Consider:

  1. Review cosine_cutoff_tensor() numerical stability (Issue #2)
  2. Add type hints to new parameters
  3. Add unit tests for get_model_definition_path()
  4. Document GPU memory overhead expectations

Nice to Have:

  1. More robust model-to-YAML mapping (Issue #3)
  2. Separate style fixes (trailing newlines) from feature commits
  3. Add stress tests for various molecule sizes
  4. Consider adding a benchmark comparison in CI (if GPU runners are available)

📝 Final Assessment

This is a high-quality PR that delivers significant performance improvements while maintaining backward compatibility. The code is well-structured, thoroughly tested, and properly documented. The critical issues are minor and easily addressable. With the device mismatch fix, this would be ready to merge.

Recommendation: Approve pending fixes for Issues #1 and #4.

Great work on bringing the community contribution up to project standards! 🚀
