v0.5.4

vbharadwaj-bk released this on 10 Feb at 03:48 (commit 72122ce).

Improvements to the JAX frontend.

Added to openequivariance.jax:

  • Jacobian-vector products (JVPs) for both TensorProduct and TensorProductConv, implemented via custom primitives, alongside vector-Jacobian products (VJPs).
  • Arbitrary higher-order derivatives, which in particular enable phonon fine-tuning in Nequix.
  • Support for jax.jit.
  • Initial support for jax.vmap and jax.pmap applied to fused convolution (not yet supported for unfused tensor products).
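The sketch below shows, in plain JAX, how a function with a hand-written JVP rule composes with the transformations listed above: forward mode, reverse mode obtained by transposing the JVP, nested (higher-order) derivatives, jax.jit, and jax.vmap. It does not use the openequivariance.jax API. `toy_tp` and `reference_tp` are hypothetical names, the bilinear einsum is only a stand-in for a tensor product, and `jax.custom_jvp` is used as a simpler analogue of the custom primitives described in this release.

```python
import jax
import jax.numpy as jnp


def reference_tp(x, y, W):
    # Hypothetical stand-in for a tensor product: a plain bilinear contraction
    # out[k] = sum_ij x[i] * y[j] * W[i, j, k].
    return jnp.einsum("i,j,ijk->k", x, y, W)


# jax.custom_jvp is used here as a simplified analogue of a custom primitive
# with a hand-written differentiation rule (not OpenEquivariance's mechanism).
@jax.custom_jvp
def toy_tp(x, y, W):
    return reference_tp(x, y, W)


@toy_tp.defjvp
def toy_tp_jvp(primals, tangents):
    x, y, W = primals
    dx, dy, dW = tangents
    out = toy_tp(x, y, W)
    # Product rule. The tangent is linear in (dx, dy, dW), so JAX can transpose
    # it to obtain the VJP; since it uses ordinary differentiable ops, it can
    # also be differentiated again for higher-order derivatives.
    dout = (
        jnp.einsum("i,j,ijk->k", dx, y, W)
        + jnp.einsum("i,j,ijk->k", x, dy, W)
        + jnp.einsum("i,j,ijk->k", x, y, dW)
    )
    return out, dout


key = jax.random.PRNGKey(0)
kx, ky, kw, kw2, kb = jax.random.split(key, 5)
x = jax.random.normal(kx, (3,))
y = jax.random.normal(ky, (4,))
W = jax.random.normal(kw, (3, 4, 5))

# Forward mode (JVP) through the custom rule.
primal_out, tangent_out = jax.jvp(toy_tp, (x, y, W), (x, y, W))

# Reverse mode (grad), derived by JAX from the JVP rule.
grad_x = jax.grad(lambda v: toy_tp(v, y, W).sum())(x)

# Higher order: Hessian of the quadratic form sum_k toy_tp(x, x, W_sq)[k].
W_sq = jax.random.normal(kw2, (3, 3, 5))
hess = jax.hessian(lambda v: toy_tp(v, v, W_sq).sum())(x)

# jit and vmap compose with the custom rule.
out_jit = jax.jit(toy_tp)(x, y, W)
xs = jax.random.normal(kb, (8, 3))
out_batched = jax.vmap(toy_tp, in_axes=(0, None, None))(xs, y, W)
```

`jax.pmap` composes in the same way but needs multiple devices, so it is omitted from this sketch.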

Fixed:

  • All output buffers in the backward and double-backward implementations of convolution are now zeroed before the kernels are called.
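A toy illustration of why zeroing can matter, under an assumption the notes do not state explicitly: if the backward kernels accumulate into their outputs, as scatter-style convolution kernels commonly do, then a buffer that is not zeroed first adds stale values to the result. The sketch mimics this with JAX's functional `.at[...].add` scatter; `scatter_accumulate`, `receivers`, and `messages` are hypothetical and not part of OpenEquivariance.

```python
import jax.numpy as jnp

# Hypothetical edge list: each edge scatters a message into its receiver node.
receivers = jnp.array([0, 0, 1, 2, 2, 2])
messages = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
num_nodes = 3


def scatter_accumulate(buffer, receivers, messages):
    # Hypothetical accumulate-style kernel: adds into whatever the buffer
    # already holds (duplicate indices are summed).
    return buffer.at[receivers].add(messages)


# Correct: accumulate into a zeroed output buffer.
clean = scatter_accumulate(jnp.zeros(num_nodes), receivers, messages)

# Buggy: reuse a buffer that still holds the previous call's result.
stale = scatter_accumulate(clean, receivers, messages)
```

Here `clean` holds the per-node sums, while `stale` double-counts every contribution.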