TODO: add a GPU-friendly version of `matrix_ops`, for example one based on `jax` or `CuPy`.
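One possible shape for this, as a rough sketch only: write each operation against an array-module argument so the same code runs on NumPy or CuPy. The contents of `matrix_ops` are not shown here, so `get_array_module`, `gram_matrix`, and `to_host` below are hypothetical names chosen for illustration, not existing functions in this project.

```python
import numpy as np


def get_array_module(prefer_gpu=True):
    """Return CuPy if it imports and a CUDA device is usable, else NumPy.

    Hypothetical helper, not part of the existing `matrix_ops`.
    """
    if prefer_gpu:
        try:
            import cupy as cp
            cp.cuda.runtime.getDeviceCount()  # raises if no usable CUDA device
            return cp
        except Exception:
            pass  # CuPy missing or no GPU: fall back to NumPy
    return np


def gram_matrix(a, xp=np):
    """Compute a.T @ a using the given array module (illustrative operation)."""
    a = xp.asarray(a)
    return a.T @ a


def to_host(x):
    """Copy a result back to a NumPy array on the host."""
    # CuPy arrays provide .get(); NumPy arrays pass through unchanged.
    return x.get() if hasattr(x, "get") else np.asarray(x)


if __name__ == "__main__":
    xp = get_array_module()
    print(to_host(gram_matrix([[1.0, 2.0], [3.0, 4.0]], xp)))
```

Passing the module explicitly keeps the CPU path as the default and makes the GPU dependency optional. A `jax` version could follow the same pattern with `jax.numpy`, though that is not sketched here.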