Not a bug, but I thought I would make you aware of this pull request from the main JAX devs:
"[JAX] Implement importing external dlpack-aware Python arrays", which allows creating `jax.Array`s from external GPU arrays asynchronously.
jax-ml/jax#17238
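For anyone who wants to try it, here is a minimal sketch. It assumes a JAX release that already includes the PR above, and NumPy >= 1.22, whose arrays implement `__dlpack__`/`__dlpack_device__`. A CPU NumPy array stands in for an external GPU array, so this does not exercise the asynchronous GPU path itself. It only shows that `jax.dlpack.from_dlpack` can now be given the producer object directly instead of a DLPack capsule.

```python
import numpy as np
import jax
import jax.dlpack

# Stand-in for an "external" array: any object exposing __dlpack__ and
# __dlpack_device__ should work. NumPy implements both for CPU arrays; GPU
# libraries (e.g. CuPy, PyTorch CUDA tensors) expose the same two methods.
host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Pass the producer object itself, not a capsule. JAX asks it for its device
# via __dlpack_device__ and then requests the capsule via __dlpack__.
imported = jax.dlpack.from_dlpack(host)

print(type(imported), imported.shape, imported.dtype)
```

With a GPU producer the call should look the same. Under the DLPack Python protocol the consumer passes its stream to `__dlpack__(stream=...)`, which lets the producer order the handoff on the device without a blocking host sync. That is presumably what "asynchronously" refers to here, but it is worth checking against the PR for a specific library and version.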