ESM2 implementation in JAX/Equinox
- The code keeps exactly the same structure as the original model, so the translation is easier to follow.
- The original PyTorch implementation is at https://github.com/facebookresearch/esm/blob/main/esm/model/esm2.py
- Equinox is a very nice JAX library for building ML models: https://github.com/patrick-kidger/equinox