# Qwix: A Quantization Library for JAX

Qwix is a JAX quantization library that supports Quantization-Aware Training
(QAT) and Post-Training Quantization (PTQ) for both XLA targets (CPU/GPU/TPU)
and ODML targets (LiteRT).

## Features

*   Supported schemas:
    *   Weight-only quantization.
    *   Dynamic-range quantization.
    *   Static-range quantization.
*   Supported modes:
    *   QAT: this mode emulates quantized serving behavior during training
        with fake quantization.
    *   PTQ: this mode achieves the best serving performance on XLA devices
        such as TPU and GPU.
    *   ODML: this mode adds the required annotations so that the LiteRT
        converter can produce full-integer models.
    *   LoRA/QLoRA: this mode enables LoRA and QLoRA on a model.
*   Supported numerics:
    *   Native: `int4`, `int8`, `fp8`.
    *   Emulated: `int1` to `int7`, `nf4`.
*   Supported array calibration methods (see the sketch after this list):
    *   `absmax`: symmetric quantization using the maximum absolute value.
    *   `minmax`: asymmetric quantization using the minimum and maximum values.
    *   `rms`: symmetric quantization using the root mean square.
    *   `fixed`: a fixed quantization range.
*   Supported JAX ops and their quantization granularity:
    *   XLA:
        *   `conv_general_dilated`: per-channel.
        *   `dot_general` and `einsum`: per-channel and sub-channel.
    *   LiteRT:
        *   `conv`, `matmul`, and `fully_connected`: per-channel.
        *   Other ops available in LiteRT: per-tensor.
*   Integration with any Flax Linen or NNX model via a single function call.
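
To make the calibration methods concrete, here is a minimal pure-JAX sketch of
how each one could derive an int8 scale from a sample array. This is
illustrative math, not Qwix's API: the helper names and the `rms` multiplier
are assumptions.

```py
import jax.numpy as jnp

def absmax_scale(x):
  # absmax: symmetric range [-max|x|, max|x|] mapped onto int8 [-127, 127].
  return jnp.max(jnp.abs(x)) / 127.0

def minmax_scale_zero_point(x):
  # minmax: asymmetric range [min(x), max(x)] mapped onto int8 [-128, 127].
  scale = (jnp.max(x) - jnp.min(x)) / 255.0
  zero_point = -128.0 - jnp.round(jnp.min(x) / scale)
  return scale, zero_point

def rms_scale(x, multiplier=3.0):
  # rms: symmetric range derived from the root mean square of the data; the
  # multiplier here is illustrative, not Qwix's actual choice.
  return multiplier * jnp.sqrt(jnp.mean(x * x)) / 127.0

def fixed_scale(lo, hi):
  # fixed: the range is supplied a priori instead of measured from data.
  return max(abs(lo), abs(hi)) / 127.0
```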

## Usage

Qwix doesn't provide a PyPI package yet. To use it, install directly from
GitHub:

```sh
pip install git+https://github.com/google/qwix
```
### Model definition

We're going to use a simple MLP model in this example. Since Qwix integrates
with models without requiring changes to their code, any model can be used
below.

```py
import jax
from flax import linen as nn

class MLP(nn.Module):
  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x

model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))
```

## Quantization config

Qwix uses a regex-based configuration system to define how a JAX model should
be quantized. Configurations are specified as a list of `QuantizationRule`.
Each rule consists of a key that matches Flax modules and a set of values that
control quantization behavior.

For example, to quantize the model above in int8 (w8a8), define the rules as
follows:

```py
import qwix

rules = [
    qwix.QuantizationRule(
        module_path='.*',     # this rule matches all modules.
        weight_qtype='int8',  # quantizes weights to int8.
        act_qtype='int8',     # quantizes activations to int8.
    ),
]
```

Unlike some other libraries that provide a limited number of **quantization
recipes**, Qwix doesn't have a list of presets. Instead, different quantization
schemas are achieved by combinations of quantization configs, as sketched
below.
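
As an illustration, here is how the schemas from the feature list could fall
out of rule combinations, using only the `QuantizationRule` fields shown in
this README:

```py
import qwix

# Weight-only quantization: only weights are given a qtype.
weight_only_rules = [
    qwix.QuantizationRule(module_path='.*', weight_qtype='int8'),
]

# Dynamic-range quantization (w8a8): activations are quantized too, with
# ranges computed on the fly at serving time.
dynamic_range_rules = [
    qwix.QuantizationRule(module_path='.*', weight_qtype='int8',
                          act_qtype='int8'),
]

# Static-range quantization additionally fixes activation ranges ahead of
# time via calibration; the rule fields controlling that are not shown in
# this README, so they are omitted here.
```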

### Post-Training Quantization

To apply PTQ to the model above, we only need to call `qwix.quantize_model`:

```py
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))
```
Now the `ptq_model` contains quantized weights, which we can verify:

```py
>>> jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']
{
  'Dense_0': {
    'kernel': WithAux(
      array=QArray(
        qvalue=ShapeDtypeStruct(shape=(16, 64), dtype=int8),
        scale=ShapeDtypeStruct(shape=(1, 64), dtype=float32),
        ...
      ),
      ...
    )
  },
  'Dense_1': {
    'kernel': WithAux(
      array=QArray(
        qvalue=ShapeDtypeStruct(shape=(64, 16), dtype=int8),
        scale=ShapeDtypeStruct(shape=(1, 16), dtype=float32),
        ...
      ),
      ...
    )
  }
}
```

### Weight quantization
Since Flax Linen modules are pure-functional, weight quantization is separate
from model quantization. To quantize weights for the `ptq_model` above, we
need to call `qwix.quantize_params`:

```py
# Floating-point params, usually loaded from checkpoints.
fp_params = ...

# Abstract quantized params, which serve as a template for quantize_params.
abs_ptq_params = jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']

# Weight quantization.
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# ptq_params contains the quantized weights and can be consumed by ptq_model.

quantized_model_output = ptq_model.apply({'params': ptq_params}, model_input)
```
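
QAT follows the same integration pattern: wrap the model with a QAT provider
and train as usual, with quantization emulated in the forward pass. A minimal
sketch — the provider name `QtProvider` is an assumption here, analogous to
the `PtqProvider` above, and is not confirmed by this README:

```py
# Assumption: a QAT provider analogous to PtqProvider; the exact name is
# not confirmed by this README.
qat_model = qwix.quantize_model(model, qwix.QtProvider(rules))

# The QAT model keeps floating-point params and emulates quantization in
# the forward pass, so it trains like the original model.
qat_params = qat_model.init(jax.random.key(0), model_input)['params']
qat_output = qat_model.apply({'params': qat_params}, model_input)
```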

## Relation with AQT

The design of Qwix was inspired by [AQT](https://github.com/google/aqt) and
borrows many great ideas from it. Here's a brief list of the similarities and
differences:

*   Qwix's `QArray` is similar to AQT's `QTensor`; both support sub-channel
    quantization.
*   AQT supports quantized training (quantized forward and backward passes),
    while Qwix's QAT is based on fake quantization, which doesn't improve
    training performance.
*   AQT provides drop-in replacements for `einsum` and `dot_general`, each of
    which must be configured separately. Qwix provides additional mechanisms
    to integrate with a whole model implicitly.
*   Applying static-range quantization is easier in Qwix thanks to its deeper
    integration with Flax.

## Citing Qwix

To cite Qwix, please use the following citation:

<!-- disableFinding(SNIPPET_INVALID_LANGUAGE) -->
```bibtex
@software{Qwix,
  title = {Qwix: A Quantization Library for Jax},
  author = {Dangyi Liu, Jiwon Shin, et al.},
  year = {2024},
  howpublished = {\url{https://github.com/google/qwix}},
}
```