
Jax CUDA errors/warnings #152

@epaillas


A simple import on a CPU node triggers the following error message:

In [4]: from acm.hod import CutskyHOD

ERROR:2026-01-30 07:41:46,407:jax._src.xla_bridge:477: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/global/common/software/desi/users/adematti/perlmutter/cosmodesiconda/20251214-1.0.0/conda/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 475, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/global/common/software/desi/users/adematti/perlmutter/cosmodesiconda/20251214-1.0.0/conda/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/global/common/software/desi/users/adematti/perlmutter/cosmodesiconda/20251214-1.0.0/conda/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE
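As a stopgap on CPU-only nodes, the message can be hidden through Python's standard logging module. The log prefix above (`jax._src.xla_bridge:477`) suggests it comes from the logger named `jax._src.xla_bridge`. The sketch below assumes that logger name and also assumes JAX does not reset that logger's level when it is imported. The helper `silence_jax_plugin_errors` is hypothetical and is not part of JAX or acm. It only hides the message. The CUDA plugin still attempts to initialize, and any genuine GPU errors reported through the same logger would also be hidden.

```python
import logging

# Assumption: logger name read off the log prefix in the traceback above.
JAX_BRIDGE_LOGGER = "jax._src.xla_bridge"


def silence_jax_plugin_errors(level=logging.CRITICAL):
    """Raise the threshold of JAX's backend-discovery logger and return the old level.

    Hypothetical helper. Call it before the first import that pulls in JAX.
    Every message below `level` from this logger is dropped, including real
    GPU initialization errors, so only use it where no GPU is expected.
    """
    logger = logging.getLogger(JAX_BRIDGE_LOGGER)
    previous = logger.level  # usually logging.NOTSET (0) in a fresh interpreter
    logger.setLevel(level)
    return previous
```

Usage would be `silence_jax_plugin_errors()` followed by `from acm.hod import CutskyHOD`. Setting the `JAX_PLATFORMS=cpu` environment variable before importing JAX is another option worth trying. Whether it prevents the CUDA plugin's `initialize()` call depends on the JAX version and has not been checked here.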

It looks as if the code will crash, but the error is only logged (despite the ERROR level) and execution continues normally. It would be good to track down which import pulled in by the HOD class triggers this, and see whether there is a way around it.
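One way to find which import pulls in JAX is to place an observing finder at the front of `sys.meta_path`. This finder records the call stack the first time `jax` is requested. Below is a minimal sketch. `FirstImportTracer` and `trace_first_import` are hypothetical helpers and are not part of acm or JAX. The finder returns `None`, so the normal import still proceeds. It only fires if `jax` is not already in `sys.modules`, so run it in a fresh interpreter.

```python
import importlib.abc
import sys
import traceback


class FirstImportTracer(importlib.abc.MetaPathFinder):
    """Record the call stack the first time `package` (or a submodule) is imported.

    find_spec always returns None, so the regular finders still load the
    module; this hook only observes.
    """

    def __init__(self, package):
        self.package = package
        self.stack = None  # list of traceback.FrameSummary once triggered

    def find_spec(self, fullname, path=None, target=None):
        hit = fullname == self.package or fullname.startswith(self.package + ".")
        if hit and self.stack is None:
            # Drop this method's own frame (the last entry) and importlib's
            # internal frames, whose filenames look like
            # "<frozen importlib._bootstrap>" or ".../importlib/__init__.py".
            frames = traceback.extract_stack()[:-1]
            self.stack = [f for f in frames if "importlib" not in f.filename]
        return None  # defer to the normal import machinery


def trace_first_import(package, action):
    """Run `action()` with the tracer installed and return the recorded stack.

    Returns None if `package` was already in sys.modules, because finders
    are not consulted for cached modules.
    """
    tracer = FirstImportTracer(package)
    sys.meta_path.insert(0, tracer)
    try:
        action()
    finally:
        sys.meta_path.remove(tracer)
    return tracer.stack
```

For this issue, you could call `stack = trace_first_import("jax", lambda: __import__("acm.hod"))` and then `print("".join(traceback.format_list(stack)))`. This should list the source lines from the outermost module down to the statement that first imports JAX. The `"importlib"` filename filter is a heuristic and would also drop user frames whose path happens to contain that string. As a coarser alternative, `python -X importtime -c "import acm.hod"` prints CPython's import tree to stderr.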
