Based on The Annotated S4, with some optimizations here and there. I'm also porting everything to the new Flax NNX API instead of the Flax Linen API.
- Install uv if you haven't already
- Run `uv sync` in the project root
- Run `uv pip install --force-reinstall "put your jax accelerator-specific install here"`
  - i.e. `uv pip install "jax[cuda13]"` for Nvidia GPUs
  - You can skip this step if you plan on using your CPU
  - Use the official docs to see what exactly you need to install for your specific hardware
- Run `uv run -m pytest` to make sure everything works as intended
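If you want a quicker check than the full test suite, the optional snippet below (not part of this project) prints which backend JAX picked up after the accelerator-specific install and runs one jit-compiled computation on it. It relies only on `jax.default_backend()` and `jax.devices()`; the `smoke_test` function is a hypothetical name used for illustration. You could save it to a file and run it with `uv run python <file>`.

```python
import jax
import jax.numpy as jnp

# Backend JAX selected: "cpu" unless an accelerator-specific jax install is active.
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)


# Hypothetical smoke test: a tiny jit-compiled function run on the default device.
@jax.jit
def smoke_test(x):
    return jnp.sum(x * x)


result = smoke_test(jnp.arange(4.0))
print(result)  # expected value: 0 + 1 + 4 + 9 = 14.0
```

If `backend` still reports `cpu` after installing an accelerator build, the `--force-reinstall` step most likely did not replace the CPU-only jax package.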
- The Recurrent Representation of an SSM
- The Naive Convolution Representation of an SSM
- FFTs for SSM Convolution
- Training SSMs
- S4 Kernel Generation
- S4 nnx Module
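The first three topics above compute the same discretized SSM in three ways: as a step-by-step recurrence, as a direct convolution with the kernel $K_l = C\bar{A}^l\bar{B}$, and as an FFT convolution. The following is a minimal sketch in JAX showing that all three agree. It assumes the bilinear discretization used in The Annotated S4 and a single-input, single-output SSM without a `D` term. The helpers `random_ssm`, `scan_ssm`, `ssm_kernel`, `naive_conv`, and `fft_conv` are hypothetical names and do not reflect this project's API.

```python
import jax
import jax.numpy as jnp


def random_ssm(key, N):
    """Hypothetical helper: a small, stable continuous-time SSM (A, B, C)."""
    a_key, b_key, c_key = jax.random.split(key, 3)
    # Shift a small random matrix by -I so the eigenvalues of A have negative real part.
    A = 0.1 * jax.random.normal(a_key, (N, N)) - jnp.eye(N)
    B = jax.random.normal(b_key, (N, 1))
    C = jax.random.normal(c_key, (1, N))
    return A, B, C


def discretize(A, B, C, step):
    """Bilinear discretization, giving x_k = Ab x_{k-1} + Bb u_k and y_k = C x_k."""
    I = jnp.eye(A.shape[0])
    BL = jnp.linalg.inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C


def scan_ssm(Ab, Bb, Cb, u):
    """Recurrent view: one state update per time step via jax.lax.scan, starting from x = 0."""
    def step(x_prev, u_k):
        x_k = Ab @ x_prev + Bb[:, 0] * u_k
        return x_k, Cb[0] @ x_k

    x0 = jnp.zeros(Ab.shape[0])
    _, ys = jax.lax.scan(step, x0, u)
    return ys


def ssm_kernel(Ab, Bb, Cb, L):
    """Naive kernel: K_l = C Ab^l Bb for l = 0..L-1 (one matrix power per entry)."""
    return jnp.stack(
        [(Cb @ jnp.linalg.matrix_power(Ab, l) @ Bb)[0, 0] for l in range(L)]
    )


def naive_conv(u, K):
    """Direct causal convolution y_k = sum_{j<=k} K_j u_{k-j}, O(L^2) work."""
    return jnp.convolve(u, K, mode="full")[: u.shape[0]]


def fft_conv(u, K):
    """Same causal convolution in O(L log L), zero-padded to 2L to avoid circular wrap-around."""
    L = u.shape[0]
    n = 2 * L
    y = jnp.fft.irfft(jnp.fft.rfft(u, n=n) * jnp.fft.rfft(K, n=n), n=n)
    return y[:L]


L, N = 16, 4
key_ssm, key_u = jax.random.split(jax.random.PRNGKey(0))
A, B, C = random_ssm(key_ssm, N)
Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
u = jax.random.normal(key_u, (L,))

y_rec = scan_ssm(Ab, Bb, Cb, u)
K = ssm_kernel(Ab, Bb, Cb, L)
y_naive = naive_conv(u, K)
y_fft = fft_conv(u, K)
print(jnp.max(jnp.abs(y_rec - y_naive)), jnp.max(jnp.abs(y_rec - y_fft)))
```

The FFT version pads both signals to length 2L because multiplying FFTs computes a circular convolution; without the padding, late outputs would wrap around into early ones. Building `K` with one matrix power per entry is the expensive part, and that cost is what S4's kernel generation avoids by exploiting the structured (diagonal plus low-rank) parameterization of `A`.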