[BUG] Large floating point errors #52

@adzcai

Description

Describe the bug

The prioritised buffer's sum tree is sensitive to floating point error: under float32, repeated incremental priority updates let the internal nodes drift away from the true sum of their leaves. I noticed during my RL training runs that, at a certain point, sampling from the buffer would start returning experiences of all zeros. Setting JAX_ENABLE_X64=true works around the issue.
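
For intuition, the drift is reproducible without the buffer at all. Here's a standalone sketch (my own illustration, not Flashbax code) that maintains a running float32 root with the same delta updates a sum tree applies when a leaf's priority changes, and compares it against an exact recomputation:

import jax
import jax.numpy as jnp

key = jax.random.key(0)
leaves = jnp.zeros(128, dtype=jnp.float32)
root = jnp.float32(0.0)

for _ in range(5_000):
    key, k1, k2 = jax.random.split(key, 3)
    idx = jax.random.randint(k1, (), 0, leaves.size)
    new_priority = jax.random.uniform(k2, (), minval=0.1, maxval=5.0)
    root = root + (new_priority - leaves[idx])  # delta update, as in a sum tree
    leaves = leaves.at[idx].set(new_priority)

print(root - leaves.sum())  # typically nonzero in float32; ~0 with JAX_ENABLE_X64=true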

To Reproduce

Here's a code snippet that illustrates the error.

import flashbax as fbx
import jax.numpy as jnp
import jax

batch_size = 32
length = 64
buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=batch_size,
    sample_batch_size=batch_size,
    sample_sequence_length=length,
    period=1,
    min_length_time_axis=length,
    max_length_time_axis=length * 4,
)

# initialize the state (just add once)
state = buffer.init(jnp.float32(0.0))  # experience is a single scalar per timestep
state = buffer.add(state, jnp.zeros((batch_size, length * 2)))
assert buffer.can_sample(state)
key = jax.random.key(42)
key, key_ = jax.random.split(key)
sample = buffer.sample(state, key_)

# update the priorities multiple times
for i in range(1000):
    key, key_ = jax.random.split(key)
    priorities = jax.random.uniform(key_, (batch_size,), minval=0.1, maxval=5.0)
    state = buffer.set_priorities(
        state,
        sample.indices,
        priorities,
    )
    nodes = state.priority_state.nodes
    if i % 100 == 0:
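        # root (nodes[0]) minus the exact sum of the leaves; should stay ~0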
        print(nodes[0] - nodes[nodes.size // 2 :].sum())

Expected behavior

I would expect the printed values to all be zero (or at least on the order of 1e-6). Instead, I get the following values:

-0.00024414062
0.0014648438
0.008056641
0.0075683594
0.0024414062
0.007080078
0.014404297
0.02368164
0.020996094
0.024414062

The drift grows with the number of priority updates. Since sampling works by descending the tree with a prefix sum drawn from [0, root), a root that exceeds the true total of the leaves makes it possible for an out-of-bounds index to be sampled.
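
To illustrate the failure mode, here is a minimal sketch of a textbook sum-tree descent (my own code, mirroring the node layout printed above with the root at nodes[0] and the leaves in nodes[nodes.size // 2 :], not the actual Flashbax implementation):

import numpy as np

def sample_index(nodes: np.ndarray, u: float) -> int:
    # Descend from the root; children of node i sit at 2*i + 1 and 2*i + 2,
    # and the leaves occupy nodes[nodes.size // 2 :].
    idx = 0
    while 2 * idx + 1 < nodes.size:
        left = 2 * idx + 1
        if u < nodes[left]:
            idx = left
        else:
            u -= nodes[left]
            idx = left + 1
    return idx - nodes.size // 2  # convert array position to leaf index

# Two real items with priority 2.0 each; the last two leaves are
# zero-priority padding. The true total is 4.0, but the root has
# drifted up to 4.1 through accumulated float32 error.
nodes = np.array([4.1, 4.0, 0.0, 2.0, 2.0, 0.0, 0.0], dtype=np.float32)
print(sample_index(nodes, 4.05))  # -> 3, a padding leaf that was never written

A draw of 4.05 is legal given the stale root of 4.1, but it exceeds the true mass of 4.0, so the descent exhausts the real items and falls through into the zero-priority padding, which is exactly the all-zero experiences I observed.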

Context (Environment)

I'm running macOS 15.3.1. I just ran pip install flashbax in a fresh environment:

absl-py==2.1.0
chex==0.1.88
etils==1.12.0
flashbax==0.1.2
flax==0.10.3
fsspec==2025.2.0
humanize==4.12.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.3
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.5
protobuf==5.29.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.2
simplejson==3.20.1
tensorstore==0.1.71
toolz==1.0.0
treescope==0.1.8
typing-extensions==4.12.2
zipp==3.21.0

Additional context

No other context.

Possible Solution

One mitigation would be to periodically recompute the internal nodes of the sum tree exactly from the leaves (say, every N priority updates, or whenever the root drifts too far from the sum of the leaves), which resets the accumulated rounding error.
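
As a minimal sketch of that idea, assuming the flat layout from the snippet above (root at nodes[0], children of node i at 2*i + 1 and 2*i + 2, leaves in the second half); rebuild_from_leaves is a hypothetical helper, not part of the Flashbax API:

import jax.numpy as jnp

def rebuild_from_leaves(nodes: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical helper: recompute every internal node exactly from the
    # leaves, discarding the error accumulated by incremental delta updates.
    n_internal = nodes.size // 2  # leaves live in nodes[n_internal:]
    for i in range(n_internal - 1, -1, -1):  # bottom-up, so children are fresh
        nodes = nodes.at[i].set(nodes[2 * i + 1] + nodes[2 * i + 2])
    return nodes

This could be triggered every N calls to set_priorities, or whenever abs(nodes[0] - nodes[nodes.size // 2 :].sum()) exceeds some tolerance. The Python loop unrolls under jit since the tree size is static; a level-by-level pairwise reduction would be the cheaper compiled form.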
