v0.5.4
Improvements to the JAX frontend.
Added to `openequivariance.jax`:
- Jacobian-vector products (JVP) for both `TensorProduct` and `TensorProductConv` via custom primitives, in addition to VJP.
- Arbitrary higher-order derivatives. In particular, this enables phonon fine-tuning in Nequix.
- `jax.jit` support.
- Initial support for `vmap` / `pmap` applied to fused convolution (support for unfused tensor products is still to do).
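The sketch below shows how the JAX transformations listed above (JVP, VJP, higher-order derivatives, `jax.jit`, `vmap`) are typically composed. It uses a hypothetical bilinear stand-in, `toy_tensor_product`, instead of the `openequivariance.jax` classes, whose constructor signatures are not described in these notes. It illustrates the transformations only and is not the library's API.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a bilinear tensor product; NOT the openequivariance API.
def toy_tensor_product(x, y):
    # Flattened outer product: bilinear in x and y, like a tensor product.
    return jnp.outer(x, y).reshape(-1)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0, 5.0])

# Forward mode: Jacobian-vector product along tangents (dx, dy).
dx = jnp.ones_like(x)
dy = jnp.zeros_like(y)
out, tangent = jax.jvp(toy_tensor_product, (x, y), (dx, dy))

# Reverse mode: vector-Jacobian product with an all-ones output cotangent.
out_vjp, vjp_fn = jax.vjp(toy_tensor_product, x, y)
gx, gy = vjp_fn(jnp.ones_like(out_vjp))


# Higher order: the Hessian of a scalar loss nests forward- and reverse-mode passes.
def loss(x):
    return jnp.sum(toy_tensor_product(x, y) ** 2)


hess = jax.hessian(loss)(x)

# jit compiles the function; vmap maps it over a leading batch axis of x only.
batched = jax.jit(jax.vmap(toy_tensor_product, in_axes=(0, None)))
xs = jnp.stack([x, 2 * x])
batched_out = batched(xs, y)
```

Because the custom primitives provide both JVP and VJP rules, nested transformations such as `jax.hessian` can be built from forward and reverse passes without a hand-written second-derivative kernel for each order.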
Fixed:
- Zero all output buffers before launching kernels in the backward and double-backward convolution implementations.