Description
I've been trying to get AlphaFold3 working on an AMD MI300X box we have on loan from HPE, and I've run into some issues caused by the test here:
Line 20 in 1f93b4b:
_hip_triton = import_from_plugin("rocm", "_triton")
I've tried:
docker.io/rocm/jax-community latest 193ba487b999
docker.io/rocm/jax-community rocm6.2.4-jax0.4.35-py3.11.10 ef50d5181ba5
docker.io/rocm/jax-community rocm6.2.3-jax0.4.34-py3.11.10 b229479e4af8
As an example, you can test this in a container built from these images:
Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
>>>
For comparison, here is the "non-community" image, which ships an older JAX:
docker.io/rocm/jax latest d949265c6ac2
Python 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
<module 'jax_rocm60_plugin._triton' from '/opt/venv/lib/python3.10/site-packages/jax_rocm60_plugin/_triton.so'>
>>>
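For anyone reproducing this, the two interactive checks above can be wrapped in a small script. This is only a sketch: `probe_hip_triton` is a hypothetical helper name, and it relies on the private `gpu_triton._hip_triton` attribute shown in the sessions above, which may be renamed or removed in other jaxlib releases.

```python
import importlib


def probe_hip_triton():
    """Describe whether jaxlib's ROCm Triton plugin module was loaded.

    Hypothetical helper: it inspects the private attribute
    `jaxlib.gpu_triton._hip_triton`, which is not a stable API.
    """
    try:
        gpu_triton = importlib.import_module("jaxlib.gpu_triton")
    except ImportError:
        # jaxlib is missing, or this jaxlib has no gpu_triton module.
        return "jaxlib.gpu_triton is not importable"

    hip_triton = getattr(gpu_triton, "_hip_triton", None)
    if hip_triton is None:
        # Matches the jax-community images, where the attribute is empty.
        return "_hip_triton is None (ROCm Triton plugin not found)"

    # Matches the older rocm/jax image, where a plugin module is returned.
    location = getattr(hip_triton, "__file__", "<unknown location>")
    return f"_hip_triton loaded from {location}"


print(probe_hip_triton())
```

Running it in each container should make the difference between the images visible without opening an interactive session.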
I'm also finding it difficult to install from wheels as per https://github.com/rocm/jax/tree/main/build/rocm, because pip can't find the wheels for the ROCm extra.
For example:
root@ea47cc7f8ef9:/# pip install jax[rocm]==0.4.38
Collecting jax[rocm]==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 52.2 MB/s eta 0:00:00
WARNING: jax 0.4.38 does not provide the extra 'rocm'
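To see what actually ended up installed after a `pip install` like the one above, a short sketch using the standard library's `importlib.metadata` can list every distribution with "jax" in its name. The helper name `jax_related_distributions` is hypothetical. The only plugin name I can confirm is `jax_rocm60_plugin`, seen in the module path from the older `rocm/jax` image. Which distribution names other releases use is an assumption to verify.

```python
from importlib import metadata


def jax_related_distributions():
    """Map installed distribution names containing 'jax' to their versions."""
    found = {}
    for dist in metadata.distributions():
        # Name can be missing for broken installs, so guard against None.
        name = dist.metadata.get("Name")
        if name and "jax" in name.lower():
            found[name] = dist.version
    return found


for name, version in sorted(jax_related_distributions().items()):
    print(f"{name}=={version}")
```

If no ROCm plugin distribution shows up in the output, that would be consistent with `_hip_triton` being empty.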
I distantly remember from installing things on NVIDIA GPUs that there is a "secret" repository you sometimes have to use to get the JAX CUDA plugins to work (https://storage.googleapis.com/jax-releases/jax_cuda_releases.html). Is there a similar thing for the ROCm plugins?
System info (python version, jaxlib version, accelerator, etc.)
docker.io/rocm/jax-community containers, tags rocm6.2.3-jax0.4.34-py3.11.10, rocm6.2.4-jax0.4.35-py3.11.10, latest
Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0]
device info: AMD Instinct MI300X-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='5ca490f32332', release='5.14.0-503.34.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Thu Mar 27 06:00:50 EDT 2025', machine='x86_64')