Add openequivariance support for JAX #24
Conversation
teddykoker left a comment
Looks great! Just a couple small suggestions.
I was proposing putting the whole

    for i, (mul, ir_in1) in enumerate(input_irreps):
        for j, (_, ir_in2) in enumerate(sh_irreps):
            for ir_out in ir_in1 * ir_in2:
                if ir_out in output_irreps or ir_out == e3nn.Irrep("0e"):
                    k = len(irreps_out_tp)

section in the
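For context, the quoted lines are the head of the usual e3nn-style loop that enumerates the symmetry-allowed tensor-product paths. A minimal sketch of the full pattern, assuming the common `irreps_out_tp`/`instructions` bookkeeping and the "uvu" connection mode (those details are assumptions, not necessarily this PR's exact code):

```python
import e3nn_jax as e3nn

# Hypothetical irreps just to make the sketch runnable; the real ones come
# from the model configuration.
input_irreps = e3nn.Irreps("32x0e + 16x1o")
sh_irreps = e3nn.Irreps("1x0e + 1x1o + 1x2e")
output_irreps = e3nn.Irreps("32x0e + 16x1o")

# Enumerate every allowed (input x spherical-harmonic) -> output path and
# record one instruction per path.
irreps_out_tp = []
instructions = []
for i, (mul, ir_in1) in enumerate(input_irreps):
    for j, (_, ir_in2) in enumerate(sh_irreps):
        for ir_out in ir_in1 * ir_in2:
            if ir_out in output_irreps or ir_out == e3nn.Irrep("0e"):
                k = len(irreps_out_tp)  # index of the new output slot
                irreps_out_tp.append((mul, ir_out))
                instructions.append((i, j, k, "uvu", True))

irreps_out_tp = e3nn.Irreps(irreps_out_tp)
```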
A couple more comments above, otherwise LGTM! Will probably want to remove the NaCl example too.
Done. I can push the changes to move the above section in |
Tests passing on CI and my GPU machine; going to sanity check training/PFT tomorrow, but otherwise LGTM!
Current training speedup
These are on a single A100 GPU. MPtrj uses a batch size of 64 with an energy/force/stress loss. PFT uses a batch size of 16 (although with much bigger supercell inputs) with an energy/force/stress/hvp loss. For multi-GPU training we use pmap. Added this to the ongoing discussion.
pmap now working with PASSIONLab/OpenEquivariance#182
There is a bit more overhead to sync gradients between the GPUs (about 10-20ms in this case), which explains why the speedup is slightly lower. |
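On the gradient-sync overhead: a minimal sketch of the kind of pmap training step this refers to, assuming a placeholder loss_fn and a plain SGD update (not the PR's actual training loop); the jax.lax.pmean all-reduce is where the extra per-step cost comes from:

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Placeholder loss; stands in for the real energy/force/stress objective.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # All-reduce across GPUs so every replica applies the same update;
    # this sync is the extra per-step overhead mentioned above.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss


# Replicate params across local devices and shard the batch along axis 0.
n_dev = jax.local_device_count()
params = jax.device_put_replicated({"w": jnp.zeros((4, 1))}, jax.local_devices())
batch = {"x": jnp.ones((n_dev, 8, 4)), "y": jnp.zeros((n_dev, 8, 1))}
params, loss = train_step(params, batch)
```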
…eeding when running pytest if torch is not actually installed.
…th backends work if both are installed)
Also, fixes a minor bug with jax-md. Will delete the NaCl example later.
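On the commit above about pytest succeeding without torch: a minimal sketch of one common way to guard backend-specific tests so the suite passes with only one backend installed and exercises both when both are present (the test names and bodies are placeholders, not the PR's actual tests):

```python
import importlib.util

import pytest

# Detect which backends are importable without actually importing them yet.
HAS_TORCH = importlib.util.find_spec("torch") is not None
HAS_JAX = importlib.util.find_spec("jax") is not None


@pytest.mark.skipif(not HAS_TORCH, reason="torch is not installed")
def test_torch_backend():
    import torch

    assert torch.ones(3).sum().item() == 3.0


@pytest.mark.skipif(not HAS_JAX, reason="jax is not installed")
def test_jax_backend():
    import jax.numpy as jnp

    assert float(jnp.ones(3).sum()) == 3.0
```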