Skip to content

jaxmd padding fix#25

Merged
teddykoker merged 2 commits intomainfrom
jaxmd-pad-fix
Feb 6, 2026
Merged

jaxmd padding fix#25
teddykoker merged 2 commits intomainfrom
jaxmd-pad-fix

Conversation

@teddykoker
Copy link
Member

Separating from #24 (thanks @abhijeetgangan )

@abhijeetgangan
Copy link
Contributor

abhijeetgangan commented Feb 6, 2026

For repro just run the test test_jax_md.py with and without kernels. The kernel ones give wrong answer for grads. The suspect is ffi'ed functions handling indexing differently. Let me know if you can repro it.

@abhijeetgangan
Copy link
Contributor

Another repro without jax-md

import jax
import jax.numpy as jnp
import openequivariance as oeq

# Distinct values to verify correctness
X = jnp.array([[1., 2.], [3., 4.], [5., 6.]], dtype=jnp.float32)  # 3 nodes
Y = jnp.ones((4, 1), dtype=jnp.float32)  # 4 edges
W = jnp.ones((4, 2), dtype=jnp.float32)  # 4 edges

# Edge 3: sender -> receiver 2
receivers = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
senders_valid = jnp.array([1, 0, 2, 1], dtype=jnp.int32)  # sender=1 for edge 3
senders_oob = jnp.array([1, 0, 2, 3], dtype=jnp.int32)    # sender=3 OOB

def jax_conv(X, Y, W, receivers, senders):
    msgs = X[senders] * Y * W
    return jax.ops.segment_sum(msgs, receivers, num_segments=X.shape[0])

problem = oeq.TPProblem("2x0e", "1x0e", "2x0e", [(0, 0, 0, "uvu", True)],
                        shared_weights=False, internal_weights=False)
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)

jax_valid = jax_conv(X, Y, W, receivers, senders_valid)
oeq_valid = tp_conv.forward(X, Y, W, receivers, senders_valid)
print(f"JAX Valid: {jax_valid[2]}")
print(f"OEQ Valid: {oeq_valid[2]}")

# OOB case
jax_oob = jax_conv(X, Y, W, receivers, senders_oob)
oeq_oob = tp_conv.forward(X, Y, W, receivers, senders_oob)
print(f"JAX OOB: {jax_oob[2]}")
print(f"OEQ OOB: {oeq_oob[2]}")

@teddykoker
Copy link
Member Author

Getting the same results! I guess when the edge indices are greater than the number of nodes we can run into undefined behavior. So it is good to add this padding to the node to avoid the index out of bounds. Thanks for clarifying!

@teddykoker teddykoker merged commit 1fb985d into main Feb 6, 2026
2 checks passed
@teddykoker teddykoker deleted the jaxmd-pad-fix branch February 6, 2026 22:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants