Steps to Reproduce
- Open Windows Command Prompt in administrator mode and install WSL with the following commands:
$ wsl --update
$ wsl --install -d Ubuntu-24.04
- In the Ubuntu (WSL) shell, install docker:
$ sudo apt update
$ sudo apt install docker.io -y
- Create and enter a jax docker container:
$ sudo docker run -it \
    -v /usr/lib/wsl/lib/libdxcore.so:/usr/lib/libdxcore.so \
    --device=/dev/dxg \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    --ipc=host \
    --shm-size 16G \
    ghcr.io/rocm/jax-ubu24.rocm711:ea9a0213a8e8a72417fb87270ee186ee9d05c811 \
    /bin/bash
- Inside the container, execute the command:
$ python3 -c "import jax; print(jax.devices())"
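If the command above prints a CPU device instead of a ROCm device, it can help to confirm which GPU discovery paths the container actually exposes. The following is a minimal diagnostic sketch, not part of the JAX or ROCm tooling: it only checks whether each path exists. The /dev/dxg and /usr/lib/libdxcore.so entries come from the docker run flags in the steps; /dev/kfd is an added assumption about what a native (non-WSL) ROCm setup would expose. In the WSL setup described here, one would expect the WSL entries to be present and the KFD topology entry to be missing.

```python
import os

# Paths relevant to ROCm GPU discovery inside the container.
# The WSL entries mirror the docker run flags used in the steps;
# the native ROCm entries are what a non-WSL host would normally expose.
PATHS = {
    "dxg device (WSL)": "/dev/dxg",
    "libdxcore.so mount (WSL)": "/usr/lib/libdxcore.so",
    "kfd device (native ROCm)": "/dev/kfd",
    "KFD topology nodes (native ROCm)": "/sys/class/kfd/kfd/topology/nodes",
}


def probe(paths=PATHS):
    """Return a mapping of label -> whether the corresponding path exists."""
    return {label: os.path.exists(path) for label, path in paths.items()}


if __name__ == "__main__":
    for label, present in probe().items():
        print(f"{label:35s} {'present' if present else 'missing'}")
```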
Description
The JAX ROCm plugin currently discovers GPUs solely through the KFD topology, deriving the GPU count from /sys/class/kfd/kfd/topology/nodes.
In ROCm-on-WSL environments this KFD topology path is not exposed, even though GPU access is available through /dev/dxg.
As a result, the GPU count is detected as 0, which triggers the "No AMD GPUs were found" early-exit path: ROCm plugin initialization fails and JAX falls back to CPU.
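To make the failure mode concrete, here is a minimal Python sketch of KFD-topology-based GPU counting, plus a /dev/dxg fallback along the lines this report implies. It is not the plugin's actual code: `count_kfd_gpus` and `rocm_gpus_possibly_present` are hypothetical names, and treating a node whose `gpu_id` file holds a non-zero value as a GPU (CPU nodes report 0) is an assumption about the KFD sysfs layout.

```python
import os

KFD_NODES = "/sys/class/kfd/kfd/topology/nodes"
DXG_DEVICE = "/dev/dxg"


def count_kfd_gpus(nodes_dir=KFD_NODES):
    """Count KFD topology nodes that look like GPUs (hypothetical helper).

    Assumption: a node is a GPU when its gpu_id file holds a non-zero value.
    Returns 0 when the topology directory is absent, as under ROCm-on-WSL.
    """
    if not os.path.isdir(nodes_dir):
        return 0
    count = 0
    for node in os.listdir(nodes_dir):
        gpu_id_path = os.path.join(nodes_dir, node, "gpu_id")
        try:
            with open(gpu_id_path) as f:
                if int(f.read().strip() or 0) != 0:
                    count += 1
        except (OSError, ValueError):
            continue  # unreadable or malformed node entry: skip it
    return count


def rocm_gpus_possibly_present(nodes_dir=KFD_NODES, dxg=DXG_DEVICE):
    """Hypothetical check: try KFD topology first, then fall back to /dev/dxg.

    The fallback only shows that the WSL GPU device is exposed; it does not
    count GPUs or confirm that they are AMD devices.
    """
    if count_kfd_gpus(nodes_dir) > 0:
        return True
    return os.path.exists(dxg)
```

With no KFD topology, `count_kfd_gpus()` returns 0, which corresponds to the early-exit path in the traceback below; a fallback like the hypothetical one above would instead let initialization proceed when /dev/dxg exists, leaving the actual device enumeration to the ROCm runtime.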
Expected
[RocmDevice(id=0)]
Actual
ERROR:jax._src.xla_bridge:475: Jax plugin configuration error: Exception when calling jax_plugins.xla_rocm7.initialize()
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/xla_bridge.py", line 473, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/usr/local/lib/python3.12/dist-packages/jax_plugins/xla_rocm7/__init__.py", line 245, in initialize
    raise ValueError("No AMD GPUs were found, skipping ROCm plugin initialization")
ValueError: No AMD GPUs were found, skipping ROCm plugin initialization
WARNING:2026-03-05 02:27:58,524:jax._src.xla_bridge:852: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]