This repository was archived by the owner on Aug 29, 2025. It is now read-only.
Incompatibility with jax_enable_x64 #82
Open
Description
Hi!
It seems that evosax explicitly uses float32 arrays internally, which leads to an exception when JAX's float64 support is manually enabled.
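For reference, the clash can be reproduced without evosax. Below is a minimal sketch, assuming the library keeps some state in an explicitly float32 array while the user-provided values are float64. `state_value` and `new_value` are illustrative names only. It shows that `jax.lax.select` refuses mixed dtypes, while `jnp.where` promotes them:

```python
import jax

# Enable 64-bit mode before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# A value stored with a hard-coded float32 dtype (illustrative stand-in for
# library state) next to a float64 value, like the fitness in the repro below.
state_value = jnp.array(1.0, dtype=jnp.float32)
new_value = jnp.array(0.5, dtype=jnp.float64)

print(state_value.dtype, new_value.dtype)  # float32 float64

# lax.select performs no type promotion, so mixed dtypes raise a TypeError.
try:
    jax.lax.select(new_value < state_value, new_value, state_value)
    select_failed = False
except TypeError as err:
    select_failed = True
    print("lax.select:", err)

# jnp.where promotes both branches to a common dtype (float64 here).
best = jnp.where(new_value < state_value, new_value, state_value)
print(best.dtype)  # float64
```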
Here is a minimal piece of code that reproduces the problematic behavior:
import jax
import numpy as np
from evosax import CMA_ES
jax.config.update("jax_enable_x64", True)
jax.print_environment_info()
p = np.random.default_rng(0).uniform(-1, 1, 100)
def loss(x):
    return ((x - p) ** 2).sum()
vloss = jax.jit(jax.vmap(loss, 0, 0))
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=8, num_dims=100)
es_params = strategy.default_params.replace(init_min=-1, init_max=1)
state = strategy.initialize(rng, es_params)
for t in range(1000):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = vloss(x)
    state = strategy.tell(x, fitness, state, es_params)
Here is the output; I've only removed folder names containing personal information:
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.10.14 (main, Aug 20 2024, 11:12:06) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node=[...], release='6.8.0-49-generic', version='#49-Ubuntu SMP PREEMPT_DYNAMIC Mon Nov 4 02:06:24 UTC 2024', machine='x86_64')
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/[...]/evolution.py", line 26, in <module>
    state = strategy.tell(x, fitness, state, es_params)
  File "/home/[...]/venv/lib/python3.10/site-packages/evosax/strategy.py", line 136, in tell
    best_member, best_fitness = get_best_fitness_member(
  File "/home/[...]/venv/lib/python3.10/site-packages/evosax/utils/helpers.py", line 22, in get_best_fitness_member
    best_fitness = jax.lax.select(
TypeError: lax.select requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
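As a possible user-side workaround (not tested against evosax itself), the fitness can be cast back to float32 before calling `tell`. The sketch below uses a hypothetical `update_best` helper that only mimics the `lax.select` pattern from the traceback; it is not evosax's actual `get_best_fitness_member`:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for a helper that keeps the running best fitness in
# float32 and updates it with lax.select (NOT evosax's actual code).
def update_best(fitness, best_fitness):
    candidate = fitness[jnp.argmin(fitness)]
    return jax.lax.select(candidate < best_fitness, candidate, best_fitness)

best_fitness = jnp.array(jnp.inf, dtype=jnp.float32)

# As in the reproduction, a float64 NumPy target promotes the loss to float64.
p = np.random.default_rng(0).uniform(-1, 1, 100)
x = jnp.zeros((8, 100), dtype=jnp.float32)
fitness = jax.vmap(lambda xi: ((xi - p) ** 2).sum())(x)
print(fitness.dtype)  # float64

# Workaround: cast fitness to the dtype the stored state uses before the update.
fitness32 = fitness.astype(best_fitness.dtype)
best_fitness = update_best(fitness32, best_fitness)
print(best_fitness.dtype)  # float32
```

In the reproduction above this would amount to `fitness = vloss(x).astype(np.float32)` (or building `p` as float32). Other evosax code paths may still mix dtypes, so this only addresses the `lax.select` call shown in the traceback.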
Thanks for your time.
Regards