
feat: AMD ROCm GPU support — device enumeration, compile mode, SDPA fix#206

Merged
nikopueringer merged 7 commits into nikopueringer:main from JamesNyeVRGuy:feat/amd-rocm-support
Mar 27, 2026

Conversation

@JamesNyeVRGuy
Contributor

AMD ROCm GPU Support (RDNA3 / RDNA4)

Adds support for AMD Radeon GPUs via PyTorch's ROCm/HIP backend. Most inference code works unchanged because HIP intercepts all torch.cuda.* calls transparently — the changes are in GPU enumeration, torch.compile configuration, and environment variable setup.

What changed

device_utils.py

  • Added GPUInfo dataclass and enumerate_gpus() function
  • GPU detection chain: nvidia-smi → amd-smi (ROCm 6.0+) → rocm-smi (legacy) → torch.cuda fallback
  • Works on NVIDIA, AMD, or mixed systems
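
The detection chain above might be sketched roughly like this (a hypothetical outline, not the PR's actual code — the `GPUInfo` fields are simplified and the per-tool output parsing is elided):

```python
import shutil
import subprocess
from dataclasses import dataclass

@dataclass
class GPUInfo:
    index: int
    name: str
    vendor: str

def enumerate_gpus() -> list[GPUInfo]:
    """Try vendor tools in preference order; return [] if none works."""
    for tool, vendor in (("nvidia-smi", "nvidia"),
                         ("amd-smi", "amd"),       # ROCm 6.0+
                         ("rocm-smi", "amd")):     # legacy ROCm
        if shutil.which(tool) is None:
            continue  # tool not on PATH, try the next one
        try:
            subprocess.run([tool], capture_output=True, timeout=10, check=True)
            # A real implementation parses the tool's output here.
            return [GPUInfo(0, f"detected via {tool}", vendor)]
        except (subprocess.SubprocessError, OSError):
            continue  # tool present but failed; fall through to the next
    return []  # caller can still fall back to torch.cuda enumeration
```

Because each probe is wrapped in its own try/except, a mixed NVIDIA+AMD system simply takes the first tool that succeeds.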

corridorkey_cli.py

  • ROCm environment variables set at CLI entry point before any torch import — covers GVM, BiRefNet, VideoMaMa (not just inference engine)
  • Auto-detects and patches pytorch-rocm-gtt if installed

CorridorKeyModule/inference_engine.py

  • Same ROCm environment variables set before import torch (for direct engine usage):
    • TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 — enables flash attention on RDNA3 (without this, SDPA silently falls back to O(n²) math kernel)
    • MIOPEN_FIND_MODE=2 — fast convolution kernel selection (reduces warmup from 5-8 min to seconds)
    • MIOPEN_LOG_LEVEL=0 — silences harmless workspace warnings
  • torch.compile uses mode="default" on ROCm (max-autotune OOM-kills 16GB cards during kernel benchmarking)
  • torch.compile skipped entirely on Windows ROCm (Triton compilation hangs indefinitely)
  • Auto-detects and uses pytorch-rocm-gtt if installed (GTT system RAM overflow for 16GB cards on Linux)
  • Dummy input freed + VRAM cache cleared after compile warmup
  • TORCHINDUCTOR_CACHE_DIR persisted to ~/.cache/corridorkey/inductor/ (default /tmp gets wiped, causing 10-20 min re-autotune on every reboot)
  • CORRIDORKEY_SKIP_COMPILE=1 env var to force eager mode for testing

pyproject.toml

  • Added rocm optional extra with install instructions in comments

README.md

  • Added supported AMD GPUs list with VRAM requirements
  • Added setup instructions for Linux native, WSL2, and Windows
  • Documented automatic ROCm behavior, first-run autotune, 16GB card workarounds, WSL2 limitations

Supported GPUs

| GPU | VRAM | 2048x2048 inference |
| --- | --- | --- |
| RX 7900 XTX | 24GB | Full speed |
| RX 7900 XT | 20GB | Full speed |
| RX 7800 XT | 16GB | Windows: yes (shared mem). Linux: needs pytorch-rocm-gtt |
| RX 9070 XT | 16GB | Same as 7800 XT |

Tested on

  • RX 7800 XT (gfx1101) — Windows native: torch.cuda.is_available()=True, SDPA float16 OK, BiRefNet 37 frames in 45s, CorridorKey inference ~4s/frame in eager mode
  • RX 7800 XT (gfx1101) — WSL2: rocminfo detects GPU, torch sees it, torch.compile works with Triton GPU kernels (not C++ fallback), but OOMs at 2048x2048 due to WSL2 VRAM hard limit (no shared memory)

Platform comparison

| Platform | torch.compile | Flash attention | VRAM overflow | Speed |
| --- | --- | --- | --- | --- |
| Linux native | Yes (default mode) | Yes (AOTriton) | GTT if pytorch-rocm-gtt installed | Best |
| Windows native | Skipped (hangs) | Yes (AOTriton) | Automatic (WDDM shared memory) | Good (~4s/frame eager) |
| WSL2 | Yes | Yes | None (hard VRAM limit) | OOM on 16GB cards |

Not changed

  • Zero modifications to model architecture, inference loop, or output pipeline
  • NVIDIA path completely unaffected (all ROCm code gated behind torch.version.hip or /opt/rocm checks)
  • MPS/CPU paths unaffected

Checklist

  • uv run pytest passes
  • uv run ruff check passes
  • uv run ruff format --check passes

@JamesNyeVRGuy force-pushed the feat/amd-rocm-support branch from 2c7421d to 7fcd2ae on March 26, 2026 at 07:43
@JamesNyeVRGuy
Contributor Author

AMD inference: (video attachment: AMDInference)

NVIDIA inference: (video attachment: NvidiaInference)

Despill strength (0–10, 10 = max despill) (5):
Enable auto-despeckle (removes tracking dots)? [y/n] (y):
Despeckle size (min pixels for a spot) (400):
Refiner strength multiplier (experimental) (1.0):
[00:31:51] INFO     Not Apple Silicon — using torch backend                                          
Generate composition previews [y/n] (y):
Use GPU accelerated post-processing (experimental) [y/n] (n):
(screenshot attachment)

@Raiden129

Hey, thanks for the work on this. This is a solid PR overall, and the runtime changes look genuinely useful. I did want to flag two issues that I think should be fixed before merge.

  1. The rocm extra does not actually install ROCm PyTorch

Right now the rocm extra looks like this:

rocm = [
    "torch==2.8.0",
    "torchvision==0.23.0",
]

But unlike the cuda extra, there is no corresponding [[tool.uv.index]] or [tool.uv.sources] entry that points to a ROCm wheel source.

The cuda extra is properly wired. When a user runs uv sync --extra cuda, they actually get CUDA wheels from the CUDA index.

The rocm extra has no equivalent wiring. Running uv sync --extra rocm just resolves torch and torchvision from the default source, which is not a ROCm build. There is nothing in the metadata that connects rocm to https://download.pytorch.org/whl/rocm6.3 or any other ROCm wheel index.

I know this is partially mentioned in the README and in the comment in pyproject.toml, where the instructions tell users to manually install the ROCm wheels first. But that is exactly why the extra is misleading. If someone sees --extra rocm, they will reasonably assume it is a supported install path in the same way --extra cuda is.

Right now it is not, and that is going to create confusion for users.

I think this should go one of two ways:

  • add a working ROCm index and source mapping similar to the CUDA setup, or
  • remove the rocm extra entirely and keep ROCm as a manual install path documented in the README

If the extra is not actually functional, I think it is better not to advertise it in the metadata.

  2. The Windows ROCm pre-import bootstrap probably does not activate on a normal Windows install

The pre-import ROCm detection in inference_engine.py and corridorkey_cli.py uses this check:

_is_rocm_system = os.environ.get("HIP_VISIBLE_DEVICES") is not None or os.path.exists("/opt/rocm")

On Linux this is reasonable, because /opt/rocm is the default ROCm install path and will usually exist on a ROCm system.

On native Windows, neither of these conditions is likely to be true by default.

/opt/rocm is a Linux path and does not exist on Windows. The ROCm docs show that the Windows side uses HIP_PATH with a default like C:/hip, not /opt/rocm.

ROCm docs:
https://rocm.docs.amd.com/en/develop/reference/env-variables.html

Also, HIP_VISIBLE_DEVICES is a GPU visibility or isolation variable, not an installation indicator. ROCm’s docs explicitly say it is unset by default.

That means on a normal native Windows ROCm install, this block probably does not run:

os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
os.environ.setdefault("MIOPEN_FIND_MODE", "2")
os.environ.setdefault("MIOPEN_LOG_LEVEL", "0")

This does not necessarily mean Windows ROCm is completely broken. The later post-import check:

is_rocm = hasattr(torch.version, "hip") and torch.version.hip

is much better, and things like skipping torch.compile on Windows ROCm should still work.

The problem is that the pre-import ROCm env vars that the PR and README describe as important probably will not get applied on a normal Windows install. In particular:

  • TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
  • MIOPEN_FIND_MODE=2

So I think the Windows claim should either be narrowed, or the detection logic should be made platform-aware. For example, check Windows ROCm paths as well, or require an explicit opt-in env var for experimental Windows ROCm support.
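
A platform-aware version of the check might look like this (a sketch only; `CORRIDORKEY_ROCM` is one possible spelling of the explicit opt-in suggested above, and the collaborators are parameterized so the logic can be exercised without a real ROCm install):

```python
import os
import sys

def is_rocm_system(platform: str = None, env=None, isdir=os.path.isdir) -> bool:
    """Hypothetical platform-aware ROCm detection."""
    platform = platform or sys.platform
    env = env if env is not None else os.environ
    if env.get("CORRIDORKEY_ROCM") == "1":
        return True  # explicit opt-in for setups auto-detection can't see
    if platform == "win32":
        hip = env.get("HIP_PATH")  # set by the Windows ROCm installer
        return hip is not None and isdir(hip)  # verify it points at a real dir
    return isdir("/opt/rocm")  # default Linux install prefix
```

Verifying that `HIP_PATH` actually points at an existing directory addresses both failure modes above: a stale variable no longer false-positives, and the explicit opt-in covers installs where the variable is not exported.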

The runtime work, compile handling, cache persistence, and documentation all look useful. I just think these two issues should be cleaned up so the install surface and platform behavior match the claims.

Add support for AMD GPUs (RDNA3/RDNA4) via ROCm's HIP translation layer.
Most PyTorch code works unchanged since HIP maps torch.cuda.* APIs, but
three areas needed explicit AMD handling:

device_utils.py:
- Add GPUInfo dataclass and enumerate_gpus() function
- Tries nvidia-smi, then amd-smi (ROCm 6.0+), then rocm-smi (legacy),
  then torch.cuda fallback — works on NVIDIA, AMD, or either

inference_engine.py:
- torch.compile uses "max-autotune-no-cudagraphs" on ROCm to avoid a
  known HIP graph segfault on large graphs (pytorch/pytorch#155720)
- Auto-sets TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 on ROCm so SDPA
  dispatches to AOTriton flash attention on RDNA3 (without this it
  silently falls back to O(n^2) math backend)
- Skips torch.compile entirely on Windows ROCm where Triton kernel
  compilation hangs indefinitely (eager fallback works fine)

pyproject.toml:
- Add "rocm" optional extra with install instructions

README.md:
- Add AMD ROCm to Hardware Requirements
- Add ROCm install instructions for Linux and Windows
- Add dedicated AMD ROCm Setup section with supported GPUs, automatic
  behavior, and known limitations

Tested on RX 7800 XT (gfx1101) Windows with AMD ROCm 7.2:
- torch.cuda.is_available() = True
- SDPA with float16 = OK
- Raw inference = OK
- torch.compile = hangs on Windows (skipped), works on Linux
@JamesNyeVRGuy force-pushed the feat/amd-rocm-support branch from 7fcd2ae to 5acb900 on March 26, 2026 at 17:54
- Extract is_rocm_system() and setup_rocm_env() into device_utils.py as
  shared functions. Both corridorkey_cli.py and inference_engine.py now
  call setup_rocm_env() instead of duplicating the detection + env var
  logic.

- device_utils.py restructured so setup_rocm_env() is defined before
  `import torch`, allowing callers to set env vars before torch init.

- Broaden pytorch_rocm_gtt exception handler from ImportError to Exception
  (catches patch() runtime errors too, not just missing package).

- Add per-entry error handling in _enumerate_amd() — malformed amd-smi
  JSON entries are skipped instead of crashing the whole enumeration.
  Add json.JSONDecodeError to the outer exception handler.

- Use top-level `import json` instead of local `import json as _json`.

- Change MIOPEN_LOG_LEVEL from 0 (silent) to 4 (suppress info/debug,
  keep warnings and errors visible).

- Compile message is now conditional: "10-20 minutes on first run (ROCm)"
  vs just "Compiling model..." on NVIDIA where it takes ~30 seconds.

- README VRAM claim clarified: ~10GB on NVIDIA, ~18GB on AMD due to HIP
  allocator overhead. Prevents scaring off 12GB NVIDIA card owners.
@nikopueringer
Owner

Big moves happening here! Let me know when you feel like this is ready to be merged.

…e error handling

1. Remove dead code: GPUInfo, enumerate_gpus(), _enumerate_nvidia(),
   _enumerate_amd() — 158 lines with zero callers in upstream. These
   are cloud-platform code that doesn't belong in this PR.

2. Remove unused imports (json, subprocess, dataclass) that were only
   needed by the dead GPU enumeration code.

3. Fix misleading "deferred so setup_rocm_env() can run first" comment.
   The env vars are read at operation time by PyTorch/MIOpen C libraries,
   not at torch import time. The code works correctly either way.

4. Remove all noqa: E402 comments — imports are now in normal order
   since the pre-import ordering was unnecessary.

5. Better exception handling in setup_rocm_env(): ImportError caught
   separately (expected, silent) vs other exceptions (logged as warning
   with traceback for debugging broken pytorch-rocm-gtt installs).

6. Cache is_rocm as self._is_rocm instance attribute instead of
   recomputing hasattr(torch.version, "hip") in both __init__ and
   _compile.
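
The ImportError-vs-Exception split described in item 5 could be sketched as follows (illustrative only; `maybe_patch_gtt` is not the PR's actual function name, though `pytorch_rocm_gtt` and `patch()` are taken from the discussion):

```python
import logging

def maybe_patch_gtt(log: logging.Logger) -> bool:
    """Try to apply pytorch-rocm-gtt; return True if the patch was applied."""
    try:
        import pytorch_rocm_gtt  # optional dependency; usually absent
        pytorch_rocm_gtt.patch()
        return True
    except ImportError:
        return False  # expected when the package isn't installed: stay silent
    except Exception:
        # Installed but broken: surface it with a traceback for debugging.
        log.warning("pytorch-rocm-gtt present but patch() failed", exc_info=True)
        return False
```

The point of the split is that a missing optional package is the normal case and should not produce noise, while a present-but-broken install is worth a logged traceback.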
@Raiden129

Hi, thanks for the follow-up changes. This is in much better shape now, but I see two minor issues worth fixing before merge.

1. setup_rocm_env() does not actually run before torch import

Right now inference_engine.py does:

from device_utils import setup_rocm_env as _setup_rocm_env
_setup_rocm_env()

but device_utils.py itself has a module-level:

import torch

So importing setup_rocm_env already imports the whole module, which imports torch before _setup_rocm_env() is ever called.

That means this docstring is not true in the current structure:

"""Must be called before importing torch."""

and this comment in inference_engine.py is also incorrect:

# noqa: E402 — no torch import

That import absolutely does trigger a torch import through device_utils.

To be clear, this still works because the env vars you are setting are only read later at SDPA / MIOpen dispatch time rather than at import torch, but that is an upstream behavior detail, not something the code guarantees.

If you want the pre-import contract to be real, setup_rocm_env() needs to live in a torch-free module, or the torch import in device_utils.py needs to be moved into the functions that actually need it.
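
The second option — moving the torch import into the functions that need it — might look like this (a minimal sketch; function names follow the PR):

```python
# device_utils.py (sketch): no module-level torch import, so importing and
# calling setup_rocm_env() really does run before torch is first imported.
import os

def setup_rocm_env() -> None:
    """Set ROCm env vars; safe to call before any torch import."""
    os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
    os.environ.setdefault("MIOPEN_FIND_MODE", "2")

def detect_best_device() -> str:
    import torch  # deferred: torch is only pulled in when actually needed
    return "cuda" if torch.cuda.is_available() else "cpu"
```

With this structure, `from device_utils import setup_rocm_env` no longer drags in torch as a side effect, so the "must be called before importing torch" docstring becomes a real contract rather than a reliance on when the env vars happen to be read.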

2. The HIP_PATH detection does not do what the docstring says

is_rocm_system() says this in the docstring:

Checks: /opt/rocm (Linux), HIP_PATH (Windows, default C:\hip)

but the implementation only does:

os.environ.get("HIP_PATH") is not None

That is not the same thing as checking the default Windows install path.

So right now this is not actually checking for a Windows ROCm install at C:\hip or any real path at all. It is only checking whether the HIP_PATH environment variable happens to be present.

That means:

  • a default Windows ROCm install can false-negative if HIP_PATH is not exported
  • a stale or unrelated HIP_PATH env var can false-positive

Contributor

@MarcelLieb MarcelLieb left a comment


That would be my suggestion for how a rocm extra could work (not tested):

pyproject.toml Outdated
Contributor


torch = [
  { index = "pytorch-cu128", extra = "cuda" },
  { index = "pytorch-rocm", extra = "rocm" },
]
torchvision = [
  { index = "pytorch-cu128", extra = "cuda" },
  { index = "pytorch-rocm", extra = "rocm" },
]

url = "https://download.pytorch.org/whl/cu128" # CUDA 12.6 doesn't support RTX 5000 Series
explicit = true
extra = "cuda"

Contributor


[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128" # CUDA 12.6 doesn't support RTX 5000 Series
explicit = true

[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm6.3"
explicit = true

conflicts = [
[
{ extra = "cuda" },
{ extra = "mlx" },
Contributor

@MarcelLieb MarcelLieb Mar 26, 2026


{ extra = "rocm" }

@@ -50,6 +50,10 @@
mlx = [
"corridorkey-mlx ; python_version >= '3.11'",
]
Contributor


rocm = [
    "torch==2.8.0",
    "torchvision==0.23.0",
]

@JamesNyeVRGuy force-pushed the feat/amd-rocm-support branch from a7a5478 to 6025c57 on March 26, 2026 at 20:57
Use uv's explicit index + extra system to route torch/torchvision to
the ROCm wheel index, matching the existing CUDA pattern. Adds
pytorch-triton-rocm as an explicit ROCm dependency since the ROCm
index is marked explicit=true.

Linux native setup is now just `uv sync --extra rocm`. WSL2 and
Windows native still require manual install (different wheel sources).
@JamesNyeVRGuy force-pushed the feat/amd-rocm-support branch from 6025c57 to f71edff on March 26, 2026 at 20:58
@JamesNyeVRGuy
Contributor Author

Thanks Marcel! Your changes DID work. I'm addressing the lint failures and will also fix the things Raiden pointed out.

…tring

Move torch import in device_utils.py from module level into the
functions that need it, so importing setup_rocm_env no longer triggers
a torch import. Restructure corridorkey_cli.py to call setup_rocm_env()
before any torch-importing modules. Remove redundant call from
inference_engine.py.

Fix is_rocm_system() docstring — remove false claim about checking
C:\hip path; it only checks the HIP_PATH env var.

Update tests for list-style torch sources and lazy torch import.
@JamesNyeVRGuy
Contributor Author

Both fixed, thanks for catching these.

1. setup_rocm_env() now genuinely runs before torch import

device_utils.py no longer has a module-level import torch — it's moved into the three functions that actually need it (detect_best_device, resolve_device, clear_device_cache). In corridorkey_cli.py, setup_rocm_env() is now called before any torch-importing modules are loaded:

# Set ROCm env vars before any module imports torch.
from device_utils import setup_rocm_env
setup_rocm_env()

from clip_manager import ...  # noqa: E402
from CorridorKeyModule.backend import resolve_backend  # noqa: E402

The redundant call in inference_engine.py is removed — the CLI entry point handles it.

2. is_rocm_system() docstring fixed

Removed the false claim about checking C:\hip. The docstring now accurately reflects what the code does: checks HIP_PATH env var (not a filesystem path), and documents CORRIDORKEY_ROCM=1 as the explicit fallback for pip-installed ROCm on Windows where auto-detection can't work.

Also in this push: the rocm extra now works via uv sync --extra rocm (Marcel's suggestion). Tested on WSL2 with an AMD GPU — confirmed it pulls torch 2.8.0+rocm6.3 with zero nvidia-* packages.

Contributor

@MarcelLieb MarcelLieb left a comment


Does uv sync --extra rocm not work with WSL?

# 3. Install AMD's WSL torch wheels (Python 3.12)
pip3 install \
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp312-cp312-linux_x86_64.whl \


Contributor Author


You're right that the WSL2 section in the README uses different wheels -> torch 2.9.1 from repo.radeon.com which are AMD's WSL-specific builds. Those were originally there because the standard ROCm wheels weren't tested on WSL2 at the time.

Now that uv sync --extra rocm works, the WSL2 section could potentially just point to that instead. The separate AMD WSL wheels + manual ROCm driver install may still be needed for some setups though (older drivers, specific GPU quirks), so I'm inclined to keep both paths documented.

@nikopueringer nikopueringer merged commit 37687cf into nikopueringer:main Mar 27, 2026
3 checks passed
JamesNyeVRGuy added a commit to JamesNyeVRGuy/CorridorKey-Cloud that referenced this pull request Mar 27, 2026
