Open
Description
`jax.numpy` appears to be slower than NumPy unless the code is also JIT-compiled or offloaded to an accelerator. When JAX is in use, the code base converts a NumPy array to a JAX array before calling a JIT-compiled function, but does not convert the result back to a NumPy array once that function returns, so any eager (non-JIT) operations that follow run on JAX arrays. We should investigate whether converting the result back to NumPy improves performance over the current approach.
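A minimal benchmark sketch for this investigation follows. It is not taken from the code base: `step`, `post_process` and `run` are hypothetical stand-ins for the real JIT function and the eager work done after it returns. The `convert_back` flag toggles between the current behaviour (keep the `jax.Array`) and the proposed change (`np.asarray` on the result). Timings will depend on array size, backend and how much eager work follows the JIT call, so treat any numbers as indicative only.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def step(x):
    # Hypothetical JIT-compiled kernel standing in for the code base's JIT function.
    return jnp.sin(x) * 2.0 + 1.0


def post_process(y):
    # Hypothetical eager follow-up work. Arithmetic operators dispatch on the operand
    # type: a jax.Array runs each op eagerly through JAX, an np.ndarray runs in NumPy.
    for _ in range(10):
        y = y * 0.5 + 1.0
    return y


def run(convert_back: bool, x_np: np.ndarray):
    x = jnp.asarray(x_np)      # current behaviour: NumPy -> JAX before the JIT call
    y = step(x)
    if convert_back:
        y = np.asarray(y)      # proposed change: JAX -> NumPy after the JIT call
    return post_process(y)


if __name__ == "__main__":
    x_np = np.random.default_rng(0).standard_normal(10_000)
    # Compile once outside the timed region so only steady-state cost is measured.
    step(jnp.asarray(x_np)).block_until_ready()
    for convert_back in (False, True):
        # np.asarray on the final result forces JAX's async dispatch to finish.
        t = timeit.timeit(lambda: np.asarray(run(convert_back, x_np)), number=200)
        print(f"convert_back={convert_back}: {t:.4f} s for 200 runs")
```

Note that on an accelerator `np.asarray` also implies a device-to-host transfer, so the comparison should be repeated on the target backend rather than only on CPU.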