Plugins not being loaded properly in gpu_triton in "community" docker images #339

@owainkenwayucl

Description

I've been trying to get AlphaFold3 working on an AMD MI300X box we have on loan from HPE, and I've run into problems caused by this check in jaxlib's gpu_triton module:

_hip_triton = import_from_plugin("rocm", "_triton")
This lookup fails in the "community" Docker images for JAX (_hip_triton ends up as None), and as a result Triton complains that we don't have the GPU version of JAX installed.
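Because the lookup leaves _hip_triton as None without raising anything in the REPL, the reason for the failure stays hidden. A minimal diagnostic sketch, using only the standard library, is to import the plugin's _triton module directly and record whatever exception comes back. The default name jax_rocm60_plugin._triton is the one seen in the working rocm/jax image. The plugin package in the community images may be named differently, so other names can be passed on the command line. probe_imports is a hypothetical helper, not part of jaxlib.

```python
import importlib
import sys


def probe_imports(module_names):
    """Try to import each module and record the outcome.

    Returns a dict mapping each module name to "ok" or to the exception
    type and message, so a silent None from a plugin lookup can be traced
    back to a concrete error (missing package, unloadable .so, etc.).
    """
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except Exception as exc:  # ImportError or anything raised during module init
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    # Default comes from the working rocm/jax image; other (unknown here)
    # plugin module names can be passed as command-line arguments.
    names = sys.argv[1:] or ["jax_rocm60_plugin._triton"]
    for name, outcome in probe_imports(names).items():
        print(f"{name}: {outcome}")
```

Run inside each container, the reported exception (module not found versus a failure while loading the shared library, for instance) should say more than the bare None.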

I've tried:

docker.io/rocm/jax-community     latest                         193ba487b999
docker.io/rocm/jax-community     rocm6.2.4-jax0.4.35-py3.11.10  ef50d5181ba5  
docker.io/rocm/jax-community     rocm6.2.3-jax0.4.34-py3.11.10  b229479e4af8  

For example, you can reproduce this in a container started from any of these images:

Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
>>>

For comparison, here is the same check in the older, "non-community" JAX image:

docker.io/rocm/jax               latest                         d949265c6ac2
Python 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxlib import gpu_triton
>>> gpu_triton._hip_triton
<module 'jax_rocm60_plugin._triton' from '/opt/venv/lib/python3.10/site-packages/jax_rocm60_plugin/_triton.so'>
>>> 
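To compare what the two images actually ship, a short standard-library script can list the installed JAX, jaxlib and plugin distributions with their versions. One possibility worth ruling out (a guess, not a confirmed cause) is that the plugin installed in the community images does not match their jaxlib version. jax_related_distributions and its keyword list are hypothetical, chosen for illustration.

```python
from importlib import metadata


def jax_related_distributions(keywords=("jax", "rocm", "triton")):
    """Return {distribution name: version} for installed packages whose
    name contains any of the keywords (case-insensitive), sorted by name."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name and any(k.lower() in name.lower() for k in keywords):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    # Print one "name version" line per matching distribution.
    for name, version in jax_related_distributions().items():
        print(f"{name:30} {version}")
```

Running it in both the rocm/jax and rocm/jax-community containers gives a side-by-side view of which plugin packages are present and at what versions.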

I'm also finding it difficult to install from wheels as described in https://github.com/rocm/jax/tree/main/build/rocm, because pip can't find the ROCm wheels.

For example:

root@ea47cc7f8ef9:/# pip install jax[rocm]==0.4.38
Collecting jax[rocm]==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 52.2 MB/s eta 0:00:00
WARNING: jax 0.4.38 does not provide the extra 'rocm'
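pip's warning comes from the package metadata: a distribution declares its optional extras in Provides-Extra fields, and pip warns when the requested extra is not among them. A minimal sketch for listing the extras an installed distribution declares is below; declared_extras is a hypothetical helper, and running it against jax inside the container shows whether any ROCm-related extra is listed for that version.

```python
from importlib import metadata


def declared_extras(dist_name):
    """Return the sorted extras listed in a distribution's Provides-Extra
    metadata. Raises metadata.PackageNotFoundError if it isn't installed."""
    md = metadata.metadata(dist_name)
    return sorted(md.get_all("Provides-Extra") or [])


if __name__ == "__main__":
    try:
        print(declared_extras("jax"))
    except metadata.PackageNotFoundError:
        print("jax is not installed in this environment")
```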

I vaguely remember from installing on NVIDIA GPUs that there is a "secret" repository you sometimes have to use to get the JAX CUDA plugins to work (https://storage.googleapis.com/jax-releases/jax_cuda_releases.html). Is there something similar for the ROCm plugins?

System info (python version, jaxlib version, accelerator, etc.)

docker.io/rocm/jax-community containers, tags rocm6.2.3-jax0.4.34-py3.11.10, rocm6.2.4-jax0.4.35-py3.11.10 and latest

Python 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.10.16 (main, Feb 17 2025, 01:40:07) [GCC 11.4.0]
device info: AMD Instinct MI300X-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='5ca490f32332', release='5.14.0-503.34.1.el9_5.x86_64', version='#1 SMP PREEMPT_DYNAMIC Thu Mar 27 06:00:50 EDT 2025', machine='x86_64')
