JAX-powered Cosmological Particle-Mesh N-body Solver
The new JaxPM v0.1.xx supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter. For the older but more stable version, install:
pip install jaxpm==0.0.2
Basic installation can be done using pip:
pip install jaxpm
For a more advanced installation with optimized distribution on GPU clusters, install jaxDecomp first. See the jaxDecomp installation instructions.
JaxPM aims to provide a modern infrastructure for differentiable PM N-body simulations using JAX:
- Keep the implementation simple and readable, written in the pure NumPy API
- Any-order forward and backward automatic differentiation
- Support automated batching using vmap
- Compatibility with external optimizer libraries like optax
- Now fully distributable on multi-GPU and multi-node systems using jaxDecomp, working with JAX v0.4.35
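The JAX features listed above can be illustrated on a toy problem. The snippet below is a minimal sketch, assuming only `jax` is installed; `cic_paint_1d`, `loss`, and `N_MESH` are hypothetical names for this example and are not part of JaxPM's API (JaxPM's actual routines operate on 3D particle meshes and differ in detail). It paints particles onto a periodic 1D mesh with cloud-in-cell weights using only `jax.numpy`, then takes reverse-mode, forward-mode, and second-order derivatives and batches the computation with `vmap`.

```python
import jax
import jax.numpy as jnp

N_MESH = 4  # toy periodic 1D mesh size (illustrative assumption)


def cic_paint_1d(positions):
    """Hypothetical helper, not JaxPM's API: deposit unit-mass particles on a
    periodic 1D mesh using cloud-in-cell (linear) weights."""
    left = jnp.floor(positions)
    frac = positions - left                   # fractional distance past the left mesh point, in [0, 1)
    i_left = left.astype(jnp.int32) % N_MESH  # wrap indices for periodic boundaries
    i_right = (i_left + 1) % N_MESH
    mesh = jnp.zeros(N_MESH)
    mesh = mesh.at[i_left].add(1.0 - frac)    # scatter-add accumulates duplicate indices
    mesh = mesh.at[i_right].add(frac)
    return mesh


def loss(positions):
    """Squared deviation of the painted density from a uniform target."""
    mesh = cic_paint_1d(positions)
    target = jnp.full(N_MESH, positions.shape[0] / N_MESH)
    return jnp.sum((mesh - target) ** 2)


positions = jnp.array([0.25, 2.5])

# Reverse mode (backpropagation): gradient of the loss w.r.t. particle positions.
grad_rev = jax.grad(loss)(positions)

# Forward mode: directional derivative of the loss along a tangent vector.
tangent = jnp.array([1.0, 0.0])
_, jvp_out = jax.jvp(loss, (positions,), (tangent,))

# Higher order: transformations compose, e.g. a Hessian via forward-over-reverse.
hessian = jax.jacfwd(jax.grad(loss))(positions)

# Automated batching: paint several realisations in one vectorised call,
# and compute one gradient per realisation.
batch = jnp.stack([positions, positions + 1.0])
meshes = jax.vmap(cic_paint_1d)(batch)
per_sample_grads = jax.vmap(jax.grad(loss))(batch)
```

For the two particles at 0.25 and 2.5, the painted mesh is [0.75, 0.25, 0.5, 0.5] (the total mass of 2 is conserved) and the reverse-mode gradient is [-1, 0]: nudging the first particle to the right moves mass from the overdense cell 0 into the underdense cell 1. Gradients such as `grad_rev` are the kind of input an external optimizer like optax consumes.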
Thanks go to these wonderful people (emoji key):
- Francois Lanusse 🤔
- Denise Lanzieri 💻
- Wassim KABALAN 💻 🚇 👀
- Hugo Simon-Onfroy 💻
- Alexandre Boucaud 👀
This project follows the all-contributors specification. Contributions of any kind welcome!