
Add openequivariance support for JAX #24

Merged
teddykoker merged 27 commits into atomicarchitects:main from abhijeetgangan:ag/oeq_integration
Feb 10, 2026

Conversation

@abhijeetgangan
Contributor

Also, fixes a minor bug with jax-md. Will delete the NaCl example later.

@teddykoker
Member

Looks great! Just a couple small suggestions.

@teddykoker
Member

I was proposing putting the whole

        for i, (mul, ir_in1) in enumerate(input_irreps):
            for j, (_, ir_in2) in enumerate(sh_irreps):
                for ir_out in ir_in1 * ir_in2:
                    if ir_out in output_irreps or ir_out == e3nn.Irrep("0e"):
                        k = len(irreps_out_tp)

section in the if kernels: part, but I guess this would result in computing the tp_irreps differently (although it is essentially the same if you look at e3nn.tensor_product()). But perhaps for conciseness it just makes sense to keep this as you had originally, to avoid too much nesting.
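For context, the nested loop above applies the standard O(3) selection rule for which output irreps a tensor product can produce. Below is a pure-Python sketch of that rule (not e3nn's actual code; irreps are modeled as hypothetical `(l, parity)` tuples and multiplicities are dropped for brevity):

```python
# Sketch of the irrep selection rule the loop above relies on.
# NOT e3nn's implementation: irreps are (l, parity) tuples,
# with parity +1 ("e") or -1 ("o"); multiplicities omitted.

def tensor_product_irreps(ir1, ir2):
    """All irreps in the tensor product of two O(3) irreps.

    For l1 (x) l2 the allowed outputs run from |l1 - l2| to l1 + l2,
    and the output parity is the product of the input parities.
    """
    l1, p1 = ir1
    l2, p2 = ir2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]

def filter_tp_paths(input_irreps, sh_irreps, output_irreps):
    """Mirror of the nested loop in the snippet: keep only outputs
    that are in the requested output irreps or the scalar irrep 0e."""
    scalar = (0, +1)  # "0e"
    paths = []
    for i, ir_in1 in enumerate(input_irreps):
        for j, ir_in2 in enumerate(sh_irreps):
            for ir_out in tensor_product_irreps(ir_in1, ir_in2):
                if ir_out in output_irreps or ir_out == scalar:
                    paths.append((i, j, ir_out))
    return paths

# Example: 1o (x) 1o -> 0e + 1e + 2e; only 1e is requested, 0e is
# always kept, and 2e is filtered out.
paths = filter_tp_paths([(1, -1)], [(1, -1)], [(1, +1)])
print(paths)  # [(0, 0, (0, 1)), (0, 0, (1, 1))]
```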

@teddykoker
Member

A couple more comments above otherwise LGTM! Will probably want to remove the NaCl example too.

@abhijeetgangan
Contributor Author

Done. I can push the changes to move the above section into if kernels: if you feel like it.

@teddykoker
Member

Tests passing on CI and my GPU machine; going to sanity check training/PFT tomorrow but otherwise GTM!

@teddykoker
Member

Current training speedup

| Task | Kernels | Time | Speedup |
| --- | --- | --- | --- |
| MPtrj training | No | 1.62 step/sec | |
| MPtrj training | Yes | 7.41 step/sec | 4.6x |
| PFT | No | 0.65 step/sec | |
| PFT | Yes | 2.62 step/sec | 4.0x |

These are on a single A100 GPU. MPtrj uses a batch size of 64 with an energy/force/stress loss. PFT uses a batch size of 16 (although it has much bigger supercell inputs) with an energy/force/stress/hvp loss.
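The speedup column is just the ratio of step rates with and without kernels, as a quick check (numbers taken from the table above):

```python
# Speedup = (steps/sec with kernels) / (steps/sec without kernels).
def speedup(with_kernels, without_kernels):
    return with_kernels / without_kernels

mptrj = speedup(7.41, 1.62)  # MPtrj training
pft = speedup(2.62, 0.65)    # PFT

print(round(mptrj, 1), round(pft, 1))  # 4.6 4.0
```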

For multi-GPU training, we use jax.pmap and are currently getting an error:

NotImplementedError: Batching rule for 'conv_fwd' not implemented

Added this to the ongoing discussion.

@teddykoker mentioned this pull request on Feb 5, 2026
@teddykoker
Member

pmap now working with PASSIONLab/OpenEquivariance#182

@teddykoker
Member

| Task | GPUs | Kernels | Time | Speedup |
| --- | --- | --- | --- | --- |
| MPtrj training | 1 x A100 | No | 1.62 step/sec | |
| MPtrj training | 1 x A100 | Yes | 7.41 step/sec | 4.6x |
| MPtrj training | 4 x A100 | No | 1.55 step/sec | |
| MPtrj training | 4 x A100 | Yes | 6.78 step/sec | 4.37x |

There is a bit more overhead to sync gradients between the GPUs (about 10-20ms in this case), which explains why the speedup is slightly lower.
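That sync overhead comes from the all-reduce mean that data-parallel training performs on the per-device gradients each step (what jax.lax.pmean does inside a pmapped step). A pure-Python sketch of the pattern, not the repo's actual training code:

```python
# Pure-Python sketch of data-parallel gradient sync. Each replica
# computes local gradients, then an all-reduce mean (the role of
# jax.lax.pmean inside a pmapped train step) makes the gradients
# identical on every replica before the optimizer update.

def pmean(per_replica_grads):
    """All-reduce mean: every replica receives the same averaged value."""
    n = len(per_replica_grads)
    avg = [sum(col) / n for col in zip(*per_replica_grads)]
    return [avg for _ in per_replica_grads]  # one copy per replica

# Four replicas, each with local gradients for two parameters.
local_grads = [
    [1.0, 2.0],
    [3.0, 4.0],
    [1.0, 0.0],
    [3.0, 2.0],
]
synced = pmean(local_grads)
print(synced[0])  # [2.0, 2.0] -- identical on every replica
```

This collective communication is the extra 10-20 ms per step mentioned above; it is a fixed cost, so it eats a slightly larger fraction of the faster kernel-enabled step time.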

@teddykoker teddykoker merged commit 66126cb into atomicarchitects:main Feb 10, 2026
1 check passed