
[Bug] CUDA Illegal Instruction Error on NVIDIA H20 GPU (Hopper sm_90) #269

@Yuiclein1001

Description

🐛 Bug Description

StreamPETR training fails with RuntimeError: CUDA error: an illegal instruction was encountered during backward pass on NVIDIA H20-3e GPU (Hopper architecture, sm_90).

The issue occurs only during the backward pass; the forward pass works perfectly.

📋 Environment

  • GPU: NVIDIA H20-3e (Hopper architecture, compute capability sm_90)
    • VRAM: 140GB HBM3
    • Driver: CUDA 12.4
  • OS: Linux 5.15.0-60-generic
  • Python: 3.8.20
  • PyTorch: 2.0.1 (compiled with CUDA 11.8)
  • MMCV: 1.7.0 (mmcv-full, compiled with CUDA 11.8)
  • MMDetection3D: 1.0.0rc6
  • NumPy: 1.22.0

🔄 Steps to Reproduce

  1. Setup environment:

    conda create -n streampetr python=3.8
    conda activate streampetr
    pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
  2. Install dependencies:

    pip install mmcv-full==1.7.0
    pip install mmdet==2.28.2
    pip install mmsegmentation==0.30.0
    cd mmdetection3d && pip install -e .
  3. Run training:

    python tools/train.py projects/configs/StreamPETR/stream_petr_r50_flash_704_bs1_semi_supervised_h20_optimized.py \
        --work-dir work_dirs/run_2stage/

❌ Error Output

Traceback (most recent call last):
  File "tools/train.py", line 269, in <module>
    main()
  File "tools/train.py", line 257, in main
    custom_train_model(
  ...
  File "/path/to/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(
RuntimeError: CUDA error: an illegal instruction was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
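As the error message itself suggests, setting `CUDA_LAUNCH_BLOCKING=1` makes kernel launches synchronous, so the traceback points at the actual failing call instead of a later API call. A minimal sketch of how I rerun the command with that flag (the `debug_env` helper name is mine):

```python
import os

def debug_env(extra=None):
    """Copy of the current environment with synchronous CUDA error reporting."""
    env = dict(os.environ)
    env["CUDA_LAUNCH_BLOCKING"] = "1"        # report errors at the launching API call
    env["TORCH_SHOW_CPP_STACKTRACES"] = "1"  # richer C++ stack traces from PyTorch
    if extra:
        env.update(extra)
    return env

# e.g. subprocess.run(["python", "tools/train.py", "<config>.py"], env=debug_env())
```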

🔍 Investigation & Attempted Solutions

✅ What Works:

  • Forward pass: Model inference works perfectly
  • Single operation backward: Simple tensor operations work fine
  • GPU recognition: CUDA device is correctly detected

❌ What Doesn't Work:

  • Complex model backward pass: Training fails during loss.backward()
  • FP16 training: Same error with mixed precision
  • Gradient checkpointing: Same error with with_cp=True
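The split between the two lists can be reproduced with a toy probe: per my testing, even this small model's backward succeeds on the H20, and the failure only appears inside the full StreamPETR `loss.backward()`. A sketch (the `probe` function is mine):

```python
import torch

def probe(device="cpu"):
    """Tiny forward + backward pass; returns True if all gradients were produced."""
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
    ).to(device)
    out = model(torch.randn(4, 16, device=device)).sum()
    out.backward()  # simple backward passes on H20; the full model's does not
    return all(p.grad is not None for p in model.parameters())

print(probe("cpu"))   # sanity check on CPU
# probe("cuda")       # on H20: this simple case passes, unlike StreamPETR training
```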

🛠️ Attempted Fixes:

  1. Disabled Flash-Attention (not compatible with sm_90):

    • Added fallback to torch.nn.functional.scaled_dot_product_attention
    • Modified projects/mmdet3d_plugin/models/utils/attention.py
  2. Disabled Gradient Checkpointing:

    • Set with_cp=False in both img_backbone and transformerlayers
  3. Disabled FP16:

    • Set fp16=None in config
  4. Compiled MMCV with matching CUDA versions:

    • Tried CUDA 11.8 (matching PyTorch)
    • Tried CUDA 12.5 (matching system driver)
  5. Attempted PyTorch 2.1.2 upgrade:

    • Failed due to ABI incompatibility with MMCV 1.7.0
    • Error: undefined symbol: _ZNK3c106SymIntltEl

All attempts still result in the same illegal instruction error.
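For reference, the fallback from attempt 1 looked roughly like this (a sketch, not the exact `attention.py` diff; the commented `flash_attn` import and the sm_90 gating are assumptions about my local change):

```python
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, dropout_p=0.0):
    """Attention with the external flash-attn kernel bypassed on sm_90.

    q, k, v: (batch, heads, seq_len, head_dim). PyTorch's built-in
    scaled_dot_product_attention selects a supported backend itself.
    """
    use_flash_attn = False
    if q.is_cuda:
        major, _ = torch.cuda.get_device_capability(q.device)
        use_flash_attn = major < 9  # this flash-attn build does not target Hopper
    if use_flash_attn:
        # from flash_attn import flash_attn_func  # external kernel, when usable
        # return flash_attn_func(q, k, v, dropout_p=dropout_p)
        pass
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
```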

🤔 Root Cause Analysis

The issue appears to be PyTorch 2.0.1's incomplete support for Hopper (sm_90) architecture:

  • PyTorch 2.0.1 (March 2023) was released shortly after the first Hopper GPUs shipped
  • According to PyTorch docs, sm_90 support was marked as "early/experimental" in 2.0.x
  • The backward pass likely uses CUDA kernels not fully optimized/tested for sm_90

Evidence:

  1. Forward pass works (simpler kernels)
  2. Backward pass fails (complex autograd kernels)
  3. Error persists across different CUDA compilation versions
  4. Cannot upgrade to PyTorch 2.1+ due to MMCV 1.7.0 ABI incompatibility
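This hypothesis is directly checkable: torch.cuda.get_arch_list() reports which architectures the installed wheel's kernels were compiled for. A small helper to compare that list against the device (the function name is mine):

```python
def supports_arch(arch_list, capability):
    """True if a compiled arch list (as returned by torch.cuda.get_arch_list())
    covers a device capability tuple such as (9, 0) for Hopper."""
    major, minor = capability
    wanted = {f"sm_{major}{minor}", f"compute_{major}{minor}"}
    return any(arch in wanted for arch in arch_list)

# Hypothetical build targeting only up to Ampere:
print(supports_arch(["sm_70", "sm_80", "sm_86"], (9, 0)))  # prints False
# On the real install:
# import torch
# print(supports_arch(torch.cuda.get_arch_list(), torch.cuda.get_device_capability()))
```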

✨ Expected Behavior

Training should work on H20 GPU without illegal instruction errors, similar to how it works on V100/A100 GPUs.

🙋 Questions for Maintainers

  1. Has StreamPETR been tested on Hopper (H100/H20) GPUs?
  2. Is there a recommended PyTorch/MMCV version combination for sm_90 GPUs?
  3. Would upgrading to MMCV 2.x + PyTorch 2.3+ solve this? (requires code migration)
  4. Are there any H20-specific branches or forks available?

💡 Suggested Solutions

Short-term (for current users):

  1. Provide Docker image with PyTorch 2.3+ (native sm_90 support)
  2. Document known incompatibilities with specific GPU architectures

Long-term:

  1. Update to MMCV 2.x and PyTorch 2.3+ for better Hopper support
  2. Add GPU architecture detection and warnings in setup
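Point 2 of the long-term list could be a small startup check. A sketch of the kind of warning I have in mind (the function name, message text, and the tested range are mine, not anything in the repo):

```python
def arch_warning(capability, min_tested=(7, 0), max_tested=(8, 6)):
    """Warning text if the GPU capability is outside the validated range, else None."""
    sm = f"sm_{capability[0]}{capability[1]}"
    if capability < min_tested:
        return f"{sm} is older than the range this repo was validated on"
    if capability > max_tested:
        return (f"{sm} is newer than the validated range; "
                "backward passes may hit untested kernels")
    return None

# On a real install:
# import torch, warnings
# if torch.cuda.is_available():
#     msg = arch_warning(torch.cuda.get_device_capability())
#     if msg:
#         warnings.warn(msg)
```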

📌 Additional Notes

  • Inference works perfectly, so users can still deploy trained models on H20
  • This affects all Hopper GPUs (H100, H20, H200, etc.)
  • Workaround: Train on older GPUs (V100/A100), deploy on H20 for inference

Any guidance or suggestions would be greatly appreciated! The H20 has 140GB VRAM which would be perfect for large-scale training if we can resolve this compatibility issue.

Thank you! 🙏
