I don't know if you are aware of these internal functions in Equinox:
https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py
They are quite powerful alternatives to JAX's `while_loop` and `scan` and might help improve the code.
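
For reference, here is a minimal sketch of the plain `jax.lax.scan` and `jax.lax.while_loop` patterns that the linked functions are meant as alternatives to. The function names `cumulative_sum` and `count_halvings` are made up for illustration and are not from your code. I have not shown the Equinox versions because they are internal and their signatures may change.

```python
import jax.numpy as jnp
from jax import lax


def cumulative_sum(xs):
    """Running sum over xs using lax.scan (hypothetical example)."""

    def step(carry, x):
        new_carry = carry + x
        # Return the new carry and the per-step output to be stacked.
        return new_carry, new_carry

    init = jnp.zeros((), xs.dtype)
    final, partials = lax.scan(step, init, xs)
    return final, partials


def count_halvings(x, threshold=1.0):
    """Count how many halvings bring x down to threshold (hypothetical example)."""

    def cond(state):
        value, _ = state
        return value > threshold

    def body(state):
        value, n = state
        return value / 2.0, n + 1

    # The carried state must keep the same shape and dtype on every iteration.
    init = (jnp.asarray(x, dtype=jnp.float32), jnp.asarray(0))
    _, n = lax.while_loop(cond, body, init)
    return n


final, partials = cumulative_sum(jnp.arange(1.0, 5.0))
steps = count_halvings(8.0)
```

Note that plain `lax.while_loop` does not support reverse-mode differentiation, which is one common reason to look at alternatives.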
Also, this might help with linear algebra:
https://cola.readthedocs.io/en/latest/