From 0af1e8e48ac8aea68556d6f0a2a8b46193661a70 Mon Sep 17 00:00:00 2001
From: HARI PRASAD L S <06hariumaraja@gmail.com>
Date: Thu, 5 Feb 2026 13:42:03 +0530
Subject: [PATCH] Update README.md

---
 README.md | 244 +++++++++++++++++++++++++-----------------------
 1 file changed, 113 insertions(+), 131 deletions(-)

# Qwix: A Quantization Library for JAX

Qwix is a JAX quantization library that supports Quantization-Aware Training
(QAT) and Post-Training Quantization (PTQ) for both XLA targets (CPU/GPU/TPU)
and ODML targets (LiteRT).

## Features

*   Supported schemas:
    *   Weight-only quantization.
    *   Dynamic-range quantization.
    *   Static-range quantization.
*   Supported modes:
    *   QAT: emulates quantized inference behavior during training using fake
        quantization (see the sketch after the Post-Training Quantization
        section below).
    *   PTQ: provides the best inference performance on XLA devices such as
        TPU and GPU.
    *   ODML: adds the required annotations so that the LiteRT converter can
        generate full-integer models.
    *   LoRA/QLoRA: enables LoRA and QLoRA on a model.
*   Supported numerics:
    *   Native: `int4`, `int8`, `fp8`.
    *   Emulated: `int1` to `int7`, `nf4`.
*   Supported array calibration methods (illustrated in the sketch after the
    install instructions below):
    *   `absmax`: symmetric quantization using the maximum absolute value.
    *   `minmax`: asymmetric quantization using the minimum and maximum values.
    *   `rms`: symmetric quantization using root mean square.
    *   `fixed`: fixed quantization range.
*   Supported JAX ops and their quantization granularity:
    *   XLA:
        *   `conv_general_dilated`: per-channel.
        *   `dot_general` and `einsum`: per-channel and sub-channel.
    *   LiteRT:
        *   `conv`, `matmul`, and `fully_connected`: per-channel.
        *   Other ops available in LiteRT: per-tensor.
*   Integration with any Flax Linen or NNX model via a single function call.

## Usage

Qwix is not available on PyPI yet. To use it, install directly from GitHub:

```sh
pip install git+https://github.com/google/qwix
```
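To make the calibration methods listed under Features concrete, here is a
minimal, per-tensor sketch of how `absmax` and `minmax` calibration can derive
int8 quantization parameters. This illustrates the general technique only; the
function names are hypothetical, this is not Qwix's implementation, and Qwix
applies calibration at per-channel or sub-channel granularity rather than
per-tensor.

```py
import jax.numpy as jnp


def absmax_quantize_int8(x):
  # Symmetric: map max(|x|) to the int8 limit 127; the zero point is 0.
  # The maximum guards against division by zero for all-zero inputs.
  scale = jnp.maximum(jnp.max(jnp.abs(x)) / 127.0, 1e-12)
  q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
  return q, scale  # dequantize with q * scale


def minmax_quantize_int8(x):
  # Asymmetric: map [min(x), max(x)] onto the full int8 range [-128, 127].
  x_min, x_max = jnp.min(x), jnp.max(x)
  scale = jnp.maximum((x_max - x_min) / 255.0, 1e-12)
  zero_point = jnp.round(-128.0 - x_min / scale)
  q = jnp.clip(jnp.round(x / scale) + zero_point, -128, 127).astype(jnp.int8)
  return q, scale, zero_point  # dequantize with (q - zero_point) * scale
```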
### Model Definition

This example uses a simple MLP model. Since Qwix integrates with models
without requiring any changes to their code, any model can be used here.

```py
import jax
from flax import linen as nn


class MLP(nn.Module):
  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x


model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))
```

## Quantization Config

Qwix uses a regex-based configuration system to define how a JAX model should
be quantized. Configurations are specified as a list of `QuantizationRule`.
Each rule consists of a key that matches Flax modules and a set of values that
control quantization behavior.

For example, to quantize the model above in int8 (w8a8), define the rules as
follows:

```py
import qwix

rules = [
    qwix.QuantizationRule(
        module_path='.*',     # this rule matches all modules.
        weight_qtype='int8',  # quantizes weights to int8.
        act_qtype='int8',     # quantizes activations to int8.
    )
]
```

Unlike some other libraries that provide a limited number of **quantization
recipes**, Qwix has no list of presets. Instead, different quantization
schemas are achieved by combining quantization configs.

### Post-Training Quantization

To apply PTQ to the model above, we only need to call `qwix.quantize_model`:

```py
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))
```

Now `ptq_model` contains quantized weights, which we can verify:

```py
>>> jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']
{
    'Dense_0': {
        'kernel': WithAux(
            array=QArray(
                qvalue=ShapeDtypeStruct(shape=(16, 64), dtype=int8),
                scale=ShapeDtypeStruct(shape=(1, 64), dtype=float32),
                ...
            ),
            ...
        )
    },
    'Dense_1': {
        'kernel': WithAux(
            array=QArray(
                qvalue=ShapeDtypeStruct(shape=(64, 16), dtype=int8),
                scale=ShapeDtypeStruct(shape=(1, 16), dtype=float32),
                ...
            ),
            ...
        )
    }
}
```
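### Quantization-Aware Training

QAT uses the same entry point with a different provider. The sketch below
assumes a `qwix.QatProvider` as the QAT counterpart of `qwix.PtqProvider`;
check the Qwix API for the exact name. The resulting model keeps
floating-point params and applies fake quantization in the forward pass, so
the usual Flax training loop works unchanged.

```py
# A sketch of QAT, assuming qwix.QatProvider mirrors qwix.PtqProvider.
qat_model = qwix.quantize_model(model, qwix.QatProvider(rules))

# Params stay float32; fake quantization is applied inside apply().
qat_params = qat_model.init(jax.random.key(0), model_input)['params']
qat_output = qat_model.apply({'params': qat_params}, model_input)
```

After training, the resulting float checkpoint can be quantized for serving
via the weight-quantization flow described next.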
### Weight Quantization

Since Flax Linen modules are pure-functional, weight quantization is separate
from model quantization. To quantize weights for the `ptq_model` above, call
`qwix.quantize_params`:

```py
# Floating-point params, usually loaded from checkpoints.
fp_params = ...

# Abstract quantized params, which serve as a template for quantize_params.
abs_ptq_params = jax.eval_shape(
    ptq_model.init, jax.random.key(0), model_input)['params']

# Weight quantization.
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# ptq_params now contains quantized weights and can be consumed by ptq_model.
quantized_model_output = ptq_model.apply({'params': ptq_params}, model_input)
```

## Relation with AQT

The design of Qwix was inspired by [AQT](https://github.com/google/aqt) and
borrows many great ideas from it. Here is a brief list of the similarities
and differences:

*   Qwix's `QArray` is similar to AQT's `QTensor`; both support sub-channel
    quantization.
*   AQT supports quantized training (quantized forward and backward passes),
    while Qwix's QAT is based on fake quantization, which does not improve
    training performance.
*   AQT provides drop-in replacements for `einsum` and `dot_general`, each of
    which must be configured separately. Qwix additionally provides mechanisms
    to integrate quantization across a whole model implicitly.
*   Applying static-range quantization is easier in Qwix thanks to its deeper
    integration with Flax.

## Citing Qwix

To cite Qwix, please use the following citation:

```bibtex
@software{Qwix,
  title = {Qwix: A Quantization Library for Jax},
  author = {Dangyi Liu, Jiwon Shin, et al.},
  year = {2024},
  howpublished = {\url{https://github.com/google/qwix}},
}
```
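## Appendix: End-to-End PTQ Example

For reference, here are the PTQ steps above combined into one self-contained
script. Everything comes from the snippets in this README, except that
`fp_params` is produced by `model.init` purely for demonstration; in practice
it would be loaded from a real checkpoint.

```py
import jax
from flax import linen as nn
import qwix


class MLP(nn.Module):
  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x


model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))

# Quantize the model function in int8 (w8a8).
rules = [
    qwix.QuantizationRule(
        module_path='.*', weight_qtype='int8', act_qtype='int8')
]
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))

# Stand-in for floating-point params loaded from a checkpoint.
fp_params = model.init(jax.random.key(1), model_input)['params']

# Quantize the weights using the abstract quantized params as a template.
abs_ptq_params = jax.eval_shape(
    ptq_model.init, jax.random.key(0), model_input)['params']
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# Run the quantized model.
output = ptq_model.apply({'params': ptq_params}, model_input)
print(output.shape)  # (8, 16)
```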