Status: Open
Labels: bug
Description
Describe the bug
The sum tree implementation is sensitive to floating-point rounding errors. During my RL training runs I noticed that, after a while, sampling from the buffer would return experiences that were all zeros. Setting JAX_ENABLE_X64=true fixed the issue.
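To illustrate the suspected mechanism, here is a minimal NumPy sketch. It is not flashbax code. It assumes that the root total is maintained by adding each priority delta (new minus old), as incremental sum-tree updates commonly do, and compares float32 with float64 accumulation. The leaf count of 8192 loosely mirrors the 32 * 256 capacity used in the reproduction below, and the update pattern is made up for illustration. Setting JAX_ENABLE_X64=true (or calling `jax.config.update("jax_enable_x64", True)` at startup) makes float64 JAX's default float dtype, which presumably puts the tree nodes in the second case.

```python
import numpy as np


def simulate_root_drift(dtype, num_leaves=8192, num_updates=1000, batch=32, seed=0):
    """Track a sum-tree root by adding per-leaf deltas, then compare it with
    the float64 sum of the leaves.

    Hypothetical simulation of incremental root updates, not flashbax's code.
    """
    rng = np.random.default_rng(seed)
    leaves = np.ones(num_leaves, dtype=dtype)
    root = dtype(num_leaves)  # exact initial total of the all-ones leaves
    for _ in range(num_updates):
        idx = rng.choice(num_leaves, size=batch, replace=False)
        new = rng.uniform(0.1, 5.0, size=batch).astype(dtype)
        for i, p in zip(idx, new):
            # Both the delta and the running total are rounded to `dtype`.
            root = root + (p - leaves[i])
            leaves[i] = p
    exact = leaves.astype(np.float64).sum()
    return abs(float(root) - exact)


drift32 = simulate_root_drift(np.float32)
drift64 = simulate_root_drift(np.float64)
print(f"float32 drift: {drift32:.6g}")
print(f"float64 drift: {drift64:.6g}")
```

The float32 drift should come out orders of magnitude larger than the float64 drift; the exact values depend on the seed and update pattern.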
To Reproduce
Here's a code snippet that illustrates the error.
```python
import flashbax as fbx
import jax
import jax.numpy as jnp

batch_size = 32
length = 64

buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=batch_size,
    sample_batch_size=batch_size,
    sample_sequence_length=length,
    period=1,
    min_length_time_axis=length,
    max_length_time_axis=length * 4,
)

# Initialize the state (just add once).
state = buffer.init(0.0)
state = buffer.add(state, jnp.zeros((batch_size, length * 2)))
assert buffer.can_sample(state)

key = jax.random.key(42)
key, key_ = jax.random.split(key)
sample = buffer.sample(state, key_)

# Update the priorities of the same sampled indices many times.
for i in range(1000):
    key, key_ = jax.random.split(key)
    priorities = jax.random.uniform(key_, (batch_size,), minval=0.1, maxval=5)
    state = buffer.set_priorities(
        state,
        sample.indices,
        priorities,
    )
    nodes = state.priority_state.nodes
    if i % 100 == 0:
        # Root value minus the sum of the leaves; this should stay near zero.
        print(nodes[0] - nodes[nodes.size // 2 :].sum())
```
Expected behavior
I would expect every printed value to be zero, or at least on the order of 1e-6. Instead, I get:
-0.00024414062
0.0014648438
0.008056641
0.0075683594
0.0024414062
0.007080078
0.014404297
0.02368164
0.020996094
0.024414062
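Because the root no longer equals the sum of the leaves, a sampled value near the top of the range `[0, root)` can fall past the last leaf that holds data. This makes it possible for an out-of-bounds index to be sampled.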
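A toy sum tree makes this failure concrete. The sketch below uses a hypothetical heap layout (root at `nodes[0]`, leaves in the second half of the array, as the indexing in the repro suggests) and hypothetical helpers `build_tree` and `sample_leaf`; it is not flashbax's sampler. Only slots 0-2 hold data, so once the root is inflated above the true leaf total, a legal draw can run past every filled leaf and land on an empty slot, which would be consistent with the all-zero experiences described above.

```python
import numpy as np


def build_tree(leaves: np.ndarray) -> np.ndarray:
    """Build a heap-ordered sum tree (hypothetical layout, power-of-two leaf count).

    nodes[0] is the root, the children of node i are 2*i + 1 and 2*i + 2,
    and the leaves occupy nodes[len(leaves) - 1:].
    """
    cap = len(leaves)
    nodes = np.zeros(2 * cap - 1, dtype=np.float32)
    nodes[cap - 1:] = leaves
    for i in range(cap - 2, -1, -1):
        nodes[i] = nodes[2 * i + 1] + nodes[2 * i + 2]
    return nodes


def sample_leaf(nodes: np.ndarray, target: float) -> int:
    """Descend from the root and return the leaf (slot) index that `target` falls in."""
    cap = (len(nodes) + 1) // 2
    i = 0
    while i < cap - 1:
        left = 2 * i + 1
        if target < nodes[left]:
            i = left
        else:
            target -= nodes[left]  # skip past the left subtree's mass
            i = left + 1
    return i - (cap - 1)


leaves = np.array([1, 2, 3, 0, 0, 0, 0, 0], dtype=np.float32)  # only slots 0-2 hold data
nodes = build_tree(leaves)
print(sample_leaf(nodes, 5.5))  # consistent tree: prints 2, a filled slot

nodes[0] += np.float32(0.5)  # simulate an upward drift of the root
target = 0.95 * nodes[0]  # a legal draw from [0, root)
print(sample_leaf(nodes, target))  # prints 7: an empty slot past the filled data
```

In this toy model a negative drift (root below the leaf total) is less harmful: draws never reach the top of the range, so the last filled leaves are simply sampled slightly less often than their priorities warrant.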
Context (Environment)
I'm running macOS 15.3.1 and ran `pip install flashbax` in a fresh environment, which installed:
absl-py==2.1.0
chex==0.1.88
etils==1.12.0
flashbax==0.1.2
flax==0.10.3
fsspec==2025.2.0
humanize==4.12.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.3
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.5
protobuf==5.29.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.2
simplejson==3.20.1
tensorstore==0.1.71
toolz==1.0.0
treescope==0.1.8
typing-extensions==4.12.2
zipp==3.21.0
Additional context
No other context.
Possible Solution
One option is to periodically recompute the entire tree from its leaves, for example whenever the difference between the root and the sum of the leaves exceeds some tolerance.
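A sketch of that idea in JAX, assuming the same heap layout the repro's indexing implies (`nodes.size == 2**(depth + 1) - 1`, root at index 0, leaves in `nodes[nodes.size // 2:]`). `rebuild_sum_tree` and `maybe_rebuild` are hypothetical helpers, not flashbax functions, and writing the result back into the buffer state is not shown. The rebuild recomputes each level by summing adjacent pairs of children, and `jax.lax.cond` keeps the drift check usable under `jit`.

```python
import jax
import jax.numpy as jnp


def rebuild_sum_tree(nodes: jax.Array) -> jax.Array:
    """Recompute every internal node from the leaves.

    Hypothetical layout (inferred from the repro's indexing): nodes has
    2**(depth + 1) - 1 entries, nodes[0] is the root, the children of node i
    are 2*i + 1 and 2*i + 2, and the leaves are nodes[nodes.size // 2:].
    """
    levels = [nodes[nodes.size // 2:]]
    while levels[0].size > 1:
        # Each parent is the sum of two adjacent children on the level below.
        levels.insert(0, levels[0].reshape(-1, 2).sum(axis=1))
    # Concatenating the levels root-first reproduces the heap ordering.
    return jnp.concatenate(levels)


@jax.jit
def maybe_rebuild(nodes: jax.Array, rel_tol: float = 1e-4) -> jax.Array:
    """Rebuild only if the root differs from the leaf total by more than rel_tol (relative)."""
    leaf_total = nodes[nodes.size // 2:].sum()
    drift = jnp.abs(nodes[0] - leaf_total)
    return jax.lax.cond(drift > rel_tol * leaf_total, rebuild_sum_tree, lambda n: n, nodes)


# Usage: build a consistent 8-leaf tree, corrupt the root, then repair it.
leaves = jnp.arange(8, dtype=jnp.float32)
nodes = rebuild_sum_tree(jnp.zeros(15, dtype=jnp.float32).at[7:].set(leaves))
drifted = nodes.at[0].add(1.0)  # simulate accumulated drift at the root
repaired = maybe_rebuild(drifted)
print(repaired[0], leaves.sum())  # the root matches the leaf total again
```

Rebuilding does not remove float32 rounding; it only bounds how much error can accumulate between rebuilds. Since the drift check is itself a full pass over the leaves, it may be cheaper to run it every N updates rather than after every `set_priorities` call.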