I would like to run my jax.numpy transforms on the CPU only.
However, I've found that MapDataset does not respect the jax.default_device(jax.devices("cpu")[0]) context manager.
You can see a clear demonstration of this issue in the following Colab notebook:
https://colab.research.google.com/drive/1hRfWBe8a09bPNnXSVjeLvcrYjoqcUA2b?usp=sharing
Is there a reason for this behavior, or could it be a bug?
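One possible explanation, which I have not verified, is that jax.default_device only affects the thread and the period in which it is active, while MapDataset applies transforms lazily and possibly on worker threads. Under that assumption, a minimal workaround sketch is to enter the context manager inside the transform itself. The thread pool below only mimics a loader that calls transforms from other threads. It is not MapDataset's actual implementation, and cpu_transform is a hypothetical name.

```python
import concurrent.futures

import jax
import jax.numpy as jnp

CPU = jax.devices("cpu")[0]


def cpu_transform(x):
    # Hypothetical transform: enter the context here so it applies on
    # whichever thread actually executes the function.
    with jax.default_device(CPU):
        return jnp.asarray(x, dtype=jnp.float32) * 2.0


# Assumption: this thread pool stands in for a loader that runs
# transforms on worker threads; it is not Grain code.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    outputs = list(pool.map(cpu_transform, range(4)))

for out in outputs:
    print(out, out.devices())
```

If the lazy or threaded execution is indeed the cause, the same pattern should apply when the function is passed to MapDataset.map, because the context is then entered at call time rather than at pipeline construction time.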