Hi Qwix team,
I’m trying to use Qwix quantized weights together with a custom JAX Pallas kernel.
In the Pallas kernel, I pass the quantized weight tensor as an input, but I also need the corresponding scale so I can dequantize the weights correctly (or fuse dequantization into the computation).
I’m not sure what the intended way is to retrieve these scales from Qwix and pass them to the kernel.
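For context, here is a minimal sketch of the kind of kernel I mean. It quantizes the weights by hand (symmetric int8, one scale per output column) instead of going through Qwix, because I don't know how Qwix exposes its scales. `quantize_per_channel` and `dequant_matmul` are my own placeholder names, not Qwix APIs. What I'm looking for is the Qwix-native way to get the equivalent of the `(w_q, scale)` pair.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def quantize_per_channel(w):
    """Hand-rolled symmetric int8 quantization (NOT Qwix): one scale per output column."""
    absmax = jnp.max(jnp.abs(w), axis=0, keepdims=True)  # shape (1, N)
    scale = jnp.maximum(absmax, 1e-8) / 127.0             # guard against all-zero columns
    w_q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return w_q, scale.astype(jnp.float32)


def dequant_matmul_kernel(x_ref, wq_ref, scale_ref, o_ref):
    # Dequantize inside the kernel: int8 values times the per-column scale,
    # then do the matmul in float32.
    w = wq_ref[...].astype(jnp.float32) * scale_ref[...]
    o_ref[...] = jnp.dot(x_ref[...], w)


def dequant_matmul(x, w_q, scale):
    # The scale is passed as a separate kernel input alongside the int8 weights.
    return pl.pallas_call(
        dequant_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], w_q.shape[1]), jnp.float32),
        interpret=True,  # lets this run on CPU for testing; drop on TPU/GPU
    )(x, w_q, scale)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 128), jnp.float32)
w = jax.random.normal(kw, (128, 256), jnp.float32)
w_q, scale = quantize_per_channel(w)
out = dequant_matmul(x, w_q, scale)
```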
Any guidance or pointers to examples would be very helpful. Thanks!