This repository was archived by the owner on Aug 29, 2025. It is now read-only.

Incompatibility with jax_enable_x64 #82

@zazbone

Description

Hi!

It seems that evosax explicitly uses float32 arrays internally, which leads to an exception when JAX's float64 support is enabled manually (jax_enable_x64).
Here is a minimal piece of code that reproduces the problem:

import jax
import numpy as np
from evosax import CMA_ES

jax.config.update("jax_enable_x64", True)
jax.print_environment_info()
p = np.random.default_rng(0).uniform(-1, 1, 100)

def loss(x):
    return ((x - p) ** 2).sum()
vloss = jax.jit(jax.vmap(loss, 0, 0))

rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=8, num_dims=100)
es_params = strategy.default_params.replace(init_min=-1, init_max=1)
state = strategy.initialize(rng, es_params)
for t in range(1000):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = vloss(x)
    state = strategy.tell(x, fitness, state, es_params)

Here is the output; I have only removed folder names containing personal information:

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.10.14 (main, Aug 20 2024, 11:12:06) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node=[...], release='6.8.0-49-generic', version='#49-Ubuntu SMP PREEMPT_DYNAMIC Mon Nov  4 02:06:24 UTC 2024', machine='x86_64')

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/[...]/evolution.py", line 26, in <module>
    state = strategy.tell(x, fitness, state, es_params)
  File "/home/[...]/venv/lib/python3.10/site-packages/evosax/strategy.py", line 136, in tell
    best_member, best_fitness = get_best_fitness_member(
  File "/home/[...]/venv/lib/python3.10/site-packages/evosax/utils/helpers.py", line 22, in get_best_fitness_member
    best_fitness = jax.lax.select(
TypeError: lax.select requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
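The same error can be reproduced with JAX alone, which suggests where the float64 comes from. This is a minimal sketch under assumptions: `x32` and `best32` are hypothetical stand-ins for float32 candidates and a float32 "best fitness" value, not evosax's actual internals. With x64 enabled, `p` is a float64 NumPy array, so the loss promotes the fitness to float64 even for float32 inputs. `jax.lax.select` does no type promotion and rejects the mixed operands, whereas `jnp.where` (the function suggested in the error tip) promotes both to float64.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 support before any arrays are created, as in the report.
jax.config.update("jax_enable_x64", True)

# NumPy's default dtype is float64, so p is a float64 array.
p = np.random.default_rng(0).uniform(-1, 1, 100)

def loss(x):
    return ((x - p) ** 2).sum()

# Hypothetical float32 candidates: float32 - float64 promotes to float64.
x32 = jnp.zeros((8, 100), dtype=jnp.float32)
fitness = jax.vmap(loss)(x32)

# Hypothetical float32 "best so far" value (assumption, not evosax code).
best32 = jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
improved = fitness.min() < best32

# lax.select requires identical dtypes, so float64 vs float32 raises TypeError.
try:
    jax.lax.select(improved, fitness.min(), best32)
    select_error = None
except TypeError as err:
    select_error = err

# jnp.where promotes both branches to a common dtype (float64 here).
best_where = jnp.where(improved, fitness.min(), best32)

print(fitness.dtype, type(select_error).__name__, best_where.dtype)
```

As a possible user-side workaround, casting the fitness to float32 before calling `tell` (e.g. `fitness = vloss(x).astype(jnp.float32)`) might avoid this particular mismatch, but I have not checked whether other parts of evosax also mix dtypes. On the library side, using `jnp.where` or casting both operands to a common dtype before `lax.select` would seem to address this call.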

Thanks for your time.
Regards