
MapDataset does not respect jax.default_device(jax.devices("cpu")[0]). #975

@james77777778

Description


I would like my transforms, which use jax.numpy, to run on the CPU only.
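For reference, here is a minimal sketch of the behavior I expect, in plain JAX without MapDataset. The `scale` function is a hypothetical stand-in for my transforms. Inside the `with` block, the jax.numpy result should be placed on the CPU device even when an accelerator is available.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]


def scale(x):
    # Hypothetical example transform written with jax.numpy.
    return jnp.asarray(x, dtype=jnp.float32) * 2.0


# While the `with` block is active, arrays created by jax.numpy ops
# should be placed on `cpu` rather than on the default accelerator.
with jax.default_device(cpu):
    y = scale([1.0, 2.0, 3.0])

print(y.devices())
```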

However, I’ve found that MapDataset does not respect jax.default_device(jax.devices("cpu")[0]).

You can see a clear demonstration of this issue in the following Colab notebook:
https://colab.research.google.com/drive/1hRfWBe8a09bPNnXSVjeLvcrYjoqcUA2b?usp=sharing

Is there a reason for this behavior?
Could this be a bug?
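A possible workaround, sketched under an assumption this issue does not confirm: if MapDataset runs transforms lazily (when elements are read) or on worker threads, a `with jax.default_device(...)` block around dataset construction may no longer be active when the transform actually executes. Entering the context manager inside the transform itself avoids depending on when or where it runs. `on_cpu` and `normalize` are hypothetical names used only for illustration.

```python
import functools
import threading

import jax
import jax.numpy as jnp

CPU = jax.devices("cpu")[0]


def on_cpu(fn):
    """Hypothetical helper: run `fn` with JAX's default device set to the CPU.

    The context manager is entered on every call, so it is active in
    whichever thread calls the transform, at the time it is called.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with jax.default_device(CPU):
            return fn(*args, **kwargs)
    return wrapper


@on_cpu
def normalize(x):
    # Hypothetical example transform written with jax.numpy.
    x = jnp.asarray(x, dtype=jnp.float32)
    return (x - x.mean()) / (x.std() + 1e-6)


if __name__ == "__main__":
    # Call the transform from a separate thread, as a data loader might.
    out = []
    worker = threading.Thread(target=lambda: out.append(normalize([1.0, 2.0, 3.0])))
    worker.start()
    worker.join()
    print(out[0].devices())
```

The wrapped function could then presumably be passed to `MapDataset.map` like any other callable. A process-wide alternative is `jax.config.update("jax_default_device", jax.devices("cpu")[0])`, but that changes the default for all JAX code in the process, not only the transforms.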

Metadata


Assignees: No one assigned
Labels: type:bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
