This is an implementation of matrix 1-norm estimation in JAX, following the algorithm specified in http://eprints.maths.manchester.ac.uk/321/1/35608.pdf
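For reference, the quantity being estimated is the matrix 1-norm, the maximum absolute column sum. The minimal JAX sketch below computes it exactly and also shows the basic fact the estimator relies on: for any x with ||x||_1 = 1, ||A x||_1 can never exceed ||A||_1. Both helper names are hypothetical and are not part of this implementation's API.

```python
import jax.numpy as jnp

def exact_one_norm(A):
    # ||A||_1 = max_j sum_i |A[i, j]|, the maximum absolute column sum.
    return jnp.abs(A).sum(axis=0).max()

def unit_lower_bound(A, x):
    # Rescale a nonzero x so that ||x||_1 = 1; then ||A x||_1 <= ||A||_1,
    # so any such product gives a lower bound on the true 1-norm.
    x = x / jnp.abs(x).sum()
    return jnp.abs(A @ x).sum()
```

An estimator of this kind searches for an x that makes this lower bound as large as possible while only forming products with A, which is why unit vectors e_j (which attain the bound for the right column) play a central role.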
The implementation passes the SciPy test suite with some minor relaxations, e.g. in the number of column resamples. The relaxed tests are documented in ./test_onenormest.py.
Basic benchmarks on a free-tier Google Colab GPU show a ~8x speedup over the SciPy CPU implementation for 4096x4096 matrices.
Existing implementations are available in SciPy (scipy.sparse.linalg.onenormest) and Octave (normest1).
The algorithm as specified is imperative and control-flow heavy, and a few of its variables have non-constant dimensions. This implementation therefore uses a few workarounds so that JAX can JIT-compile it.
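The main loop has many conditional early breaks. Traced code cannot use a Python break or return that depends on a traced value, so we handle these with manual continuation passing: the rest of the computation after each break point is passed in as a branch of jax.lax.cond.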
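The pattern can be illustrated on a toy function with two early returns (this is a sketch of the general technique, not the actual main loop). Each early-break check becomes a jax.lax.cond whose true branch returns the current result and whose false branch is a function holding everything that would have run afterwards. The names imperative, step2 and cps are hypothetical.

```python
import jax
import jax.numpy as jnp

def imperative(x):
    # Reference version with Python early returns. It runs eagerly, but
    # cannot be jitted because the branch conditions depend on traced values.
    y = x * 2.0
    if y.sum() > 10.0:
        return y
    z = y + 1.0
    if z.max() > 5.0:
        return z
    return z * 3.0

def step2(y):
    # Continuation: the part of the computation after the first break point.
    z = y + 1.0
    return jax.lax.cond(z.max() > 5.0, lambda z: z, lambda z: z * 3.0, z)

@jax.jit
def cps(x):
    y = x * 2.0
    # If the first break condition holds, return y; otherwise hand y to the
    # continuation, which contains the remaining checks.
    return jax.lax.cond(y.sum() > 10.0, lambda y: y, step2, y)
```

Both branches of every jax.lax.cond must return values with the same shape and dtype, which is one reason the fixed-dimension arrays described below are needed.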
ind_hist and ind must have fixed dimensions.
In the SciPy implementation and in Higham, ind_hist is a growable array that stores the indices of the unit vectors already used. In the Octave implementation, ind_hist is a fixed-size array in which index j is set to 1 once e_j has been used. We follow the Octave approach so that the array keeps a fixed size.
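A minimal sketch of the Octave-style history in JAX, assuming a boolean mask in place of the 0/1 integers (the two are equivalent for this purpose). The names mark_used and was_used are hypothetical. Because the mask has a fixed length, both recording and querying keep static shapes under jax.jit, unlike appending to a growable list.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mark_used(ind_hist, ind):
    # One slot per unit vector; setting slot j records that e_j has been
    # tested. The shape of ind_hist never changes.
    return ind_hist.at[ind].set(True)

@jax.jit
def was_used(ind_hist, ind):
    # Membership test is a simple gather instead of a search over a list.
    return ind_hist[ind]
```

Usage: start from jnp.zeros(n, dtype=bool) and call mark_used with each batch of tested indices.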
In Higham, ind has shape (n,), but only its first t entries are ever read: they are written to ind_hist, and ind is indexed by the column indices of Y, which has shape (n, t). Because each elementary vector is tested at most once, there is no guarantee that t untested elementary vectors remain on every iteration. We handle this by filling the unused entries of ind with the sentinel value n, which places the zero vector in the corresponding columns of X instead of an elementary vector. A zero column yields a norm estimate of 0, which is always a valid underestimate of the 1-norm. Because ind can hold the sentinel value n, ind_hist must be extended to length n + 1; recording the sentinel as used in ind_hist has no effect.
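A minimal sketch of the sentinel scheme, assuming candidate indices arrive as a priority-ordered array order (in Higham, the indices sorted by the heuristic scores) and that ind_hist is a boolean mask of length n + 1. The helpers select_ind and build_x are hypothetical illustrations, not this implementation's functions. select_ind takes the first t candidates not yet in ind_hist and pads with n; build_x turns ind into the (n, t) matrix X.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)
def select_ind(order, ind_hist, t):
    # order:    (n,) candidate indices in priority order, values in [0, n).
    # ind_hist: (n + 1,) bool mask; slot n is reserved for the sentinel.
    # t:        static number of columns; assumes t <= n.
    n = order.shape[0]
    used = ind_hist[order]
    # Unique sort keys: unused candidates first, each group in original order.
    key = used.astype(jnp.int32) * n + jnp.arange(n)
    first = jnp.argsort(key)[:t]
    # Any slot that could only be filled by a used index gets the sentinel n.
    return jnp.where(used[first], n, order[first])

def build_x(ind, n, dtype=jnp.float32):
    # Column j is e_{ind[j]}; the sentinel n matches no row index in
    # range(n), so its column is the zero vector.
    return (jnp.arange(n)[:, None] == ind[None, :]).astype(dtype)
```

The sentinel columns of A @ X are zero, so their column 1-norms are 0 and can never exceed a real estimate. Updating the history with ind_hist.at[ind].set(True) simply sets slot n when a sentinel is present, which nothing reads.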
