77 changes: 77 additions & 0 deletions kazu/quantization/README.md
@@ -0,0 +1,77 @@
# Accelerating Biomedical NER with Quantization

## Introduction

Quantization represents model weights and/or activations at lower precision, with the aim of reducing the computational cost of inference. The KAZU framework is designed for efficient and scalable document processing without requiring a GPU. However, support for quantization on CPU is limited, as CPUs generally lack native support for low-precision data types (e.g. `bfloat16` or `int4`).

In this project, we explore the use of quantization to accelerate CPU inference for biomedical named entity recognition. Specifically, we apply 8-bit quantization to the weights and activations (`W8A8`). This enables inference speedups on CPUs supporting the `VNNI` ([Vector Neural Network Instructions](https://en.wikichip.org/wiki/x86/avx512_vnni)) instruction set extension.
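As a rough illustration of what 8-bit quantization means (a sketch of the underlying arithmetic only — KAZU delegates the real work to PyTorch's PT2E flow, shown later), symmetric per-tensor quantization maps each float to an integer in `[-127, 127]` via a single scale factor:

```python
# Illustrative sketch of symmetric per-tensor int8 quantization.
# This is NOT KAZU's implementation; it only shows the arithmetic behind "W8A8".

def quantize_int8(values: list[float]) -> tuple[list[int], float]:
    """Map floats to int8 range [-127, 127] using one per-tensor scale."""
    scale = max(abs(v) for v in values) / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize(q: list[int], scale: float) -> list[float]:
    """Recover approximate float values from the int8 representation."""
    return [x * scale for x in q]

weights = [0.5, -1.27, 0.003, 1.0]
q, scale = quantize_int8(weights)
approx = dequantize(q, scale)
# The representation is lossy: each value is recovered to within scale / 2.
```

With both weights and activations stored this way, matrix multiplications can run in int8 (which is where `VNNI` helps), with scales applied afterwards to recover float outputs.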

## Supported hardware

The following Linux command can be used to verify whether the target CPU supports `VNNI`. On supported systems, it outputs either `avx512_vnni` or `avx_vnni`.

```shell
lscpu | grep -o "\S*_vnni"
```
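The same check can be scripted by scanning the CPU flags in `/proc/cpuinfo` (a Linux-only sketch; `cpu_supports_vnni` is a hypothetical helper, not part of KAZU):

```python
# Linux-only sketch: detect VNNI support by scanning CPU flags.
# cpu_supports_vnni is a hypothetical helper, not a KAZU API.
from pathlib import Path

def cpu_supports_vnni() -> bool:
    try:
        cpuinfo = Path("/proc/cpuinfo").read_text()
    except OSError:
        return False  # non-Linux or unreadable: assume unsupported
    return "avx512_vnni" in cpuinfo or "avx_vnni" in cpuinfo

print(cpu_supports_vnni())
```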

## Usage

> [!IMPORTANT]
> Quantization is currently experimental as it relies on PyTorch prototype features.

The following instructions apply to the [`TransformersModelForTokenClassificationNerStep`](https://astrazeneca.github.io/KAZU/_autosummary/kazu.steps.ner.hf_token_classification.html#kazu.steps.ner.hf_token_classification.TransformersModelForTokenClassificationNerStep).

To enable quantization, set the following environment variables. TorchInductor is required to lower the quantized model to optimized instructions.

```shell
export KAZU_ENABLE_INDUCTOR=1
export KAZU_ENABLE_QUANTIZATION=1
```

Optionally, TorchInductor [Max-Autotune](https://pytorch.org/tutorials/prototype/max_autotune_on_CPU_tutorial.html) can be enabled to automatically profile candidate kernels and select the best-performing implementations.

```shell
export KAZU_ENABLE_MAX_AUTOTUNE=1
```
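For reference, the step treats a flag as enabled only when its variable is set to the literal string `1`. The sketch below mirrors the `_Feature` enum this change adds (reproduced here for illustration; the public class name differs from KAZU's private one):

```python
import os
from enum import Enum

# Mirrors the _Feature enum added to the NER step: a feature counts as
# enabled only when its environment variable is exactly "1".
class Feature(str, Enum):
    INDUCTOR = "KAZU_ENABLE_INDUCTOR"
    MAX_AUTOTUNE = "KAZU_ENABLE_MAX_AUTOTUNE"
    QUANTIZATION = "KAZU_ENABLE_QUANTIZATION"

    @property
    def is_enabled(self) -> bool:
        return os.getenv(self.value, "0") == "1"

os.environ["KAZU_ENABLE_QUANTIZATION"] = "1"
print(Feature.QUANTIZATION.is_enabled)  # True
os.environ["KAZU_ENABLE_QUANTIZATION"] = "true"
print(Feature.QUANTIZATION.is_enabled)  # False: only "1" enables a flag
```

In particular, values such as `true` or `yes` do not enable a feature.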

## Benchmarking

To benchmark inference performance, we use [`evaluate_script.py`](/kazu/training/evaluate_script.py) with the [`multilabel_biomedBERT`](/resources/kazu_model_pack_public/multilabel_biomedBERT) model, on the dataset from the [training multilabel NER](https://astrazeneca.github.io/KAZU/training_multilabel_ner.html) guide. To simulate a long workload, we use the entire test set (365 documents); for a short workload, we use the first 10 documents (alphabetically).

The following benchmark results were collected on an Intel Xeon Gold 6252 CPU (single core) with PyTorch 2.6.0.

### Short workload (10 documents)

| Method                            | Mean F1 | Duration (s) | Speedup |
| :-------------------------------- | ------: | -----------: | ------: |
| Baseline | 0.9697 | 373.39 | 1.00 |
| Baseline (Inductor) | 0.9697 | 357.15 | 1.05 |
| Baseline (Inductor, Max-Autotune) | 0.9697 | 358.99 | 1.04 |
| W8A8 (Inductor) | 0.9656 | 194.84 | 1.92 |
| W8A8 (Inductor, Max-Autotune) | 0.9656 | 195.55 | 1.91 |

### Long workload (365 documents)

| Method                            | Mean F1 | Duration (s) | Speedup |
| :-------------------------------- | ------: | -----------: | ------: |
| Baseline | 0.9560 | 14797.50 | 1.00 |
| Baseline (Inductor) | 0.9560 | 13450.01 | 1.10 |
| Baseline (Inductor, Max-Autotune) | 0.9560 | 13469.89 | 1.10 |
| W8A8 (Inductor) | 0.9519 | 6642.87 | 2.23 |
| W8A8 (Inductor, Max-Autotune) | 0.9519 | 6801.01 | 2.18 |
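The speedup column is the baseline duration divided by the method's duration, e.g. for W8A8 (Inductor) on the long workload:

```python
# Speedup relative to the unquantized baseline, rounded as in the tables above.
baseline = 14797.50       # Baseline duration (s), long workload
w8a8_inductor = 6642.87   # W8A8 (Inductor) duration (s), long workload
speedup = round(baseline / w8a8_inductor, 2)
print(speedup)  # 2.23
```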

## Conclusion

In our benchmarks, W8A8 quantization via TorchInductor achieved a speedup of more than 2× over the baseline (W32A32) model, at the cost of only a 0.4-point reduction in mean F1 score. For short workloads, the performance benefit of quantization is slightly smaller (about 1.9×). Finally, we did not observe any additional performance benefit from the TorchInductor Max-Autotune mode.

## Future work

- [ ] Load exported quantized models from checkpoints.
- [ ] Support mixed `int8` and `bfloat16` for speedups on newer CPUs.

## Resources

- [Tuning Guide for Deep Learning with Intel AVX-512 and Intel Deep Learning Boost on 3rd Generation Intel Xeon Scalable Processors](https://www.intel.com/content/www/us/en/developer/articles/guide/deep-learning-with-avx512-and-dl-boost.html)
- [PyTorch 2 Export Quantization with X86 Backend through Inductor](https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html)
- [(prototype) PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html)
- [Using Max-Autotune Compilation on CPU for Better Performance](https://pytorch.org/tutorials/prototype/max_autotune_on_CPU_tutorial.html)
Empty file added kazu/quantization/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions kazu/quantization/int8_x86_quantizer.py
@@ -0,0 +1,42 @@
import torch
from torch.ao.quantization import move_exported_model_to_eval
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
X86InductorQuantizer,
get_default_x86_inductor_quantization_config,
)
from torch.export import export_for_training
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy


class _Int8X86Quantizer:
    def __init__(self) -> None:
        # Dynamic quantization: activation scales are computed at runtime.
        quantization_config = get_default_x86_inductor_quantization_config(is_dynamic=True)

        quantizer = X86InductorQuantizer()
        quantizer.set_global(quantization_config)
        self.quantizer = quantizer

    @torch.inference_mode()
    def quantize(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
    ) -> torch.nn.Module:
        # Build a max-length example input so the exported graph covers the
        # full sequence length used at inference time.
        example_inputs = tokenizer(
            "",
            max_length=max_length,
            padding=PaddingStrategy.MAX_LENGTH,
            return_tensors="pt",
        )
        example_inputs = dict(example_inputs.to(model.device))

        exported_model = export_for_training(model, args=(), kwargs=example_inputs).module()

        # Insert observers, run one calibration pass, then convert the graph
        # to its quantized representation.
        exported_model = prepare_pt2e(exported_model, self.quantizer)  # type: ignore[arg-type]
        exported_model(**example_inputs)

        exported_model = convert_pt2e(exported_model)
        return move_exported_model_to_eval(exported_model)  # type: ignore[no-any-return]
33 changes: 33 additions & 0 deletions kazu/steps/ner/hf_token_classification.py
@@ -1,5 +1,7 @@
import logging
import os
from collections.abc import Iterator
from enum import Enum
from typing import Any, Iterable, Optional, cast

import torch
@@ -27,6 +29,16 @@
logger = logging.getLogger(__name__)


class _Feature(str, Enum):
    INDUCTOR = "KAZU_ENABLE_INDUCTOR"
    MAX_AUTOTUNE = "KAZU_ENABLE_MAX_AUTOTUNE"
    QUANTIZATION = "KAZU_ENABLE_QUANTIZATION"

    @property
    def is_enabled(self) -> bool:
        # A feature is enabled only when its environment variable is set to "1".
        return os.getenv(self.value, "0") == "1"


class HFDataset(IterableDataset[dict[str, Any]]):
    def __getitem__(self, index: int) -> dict[str, Any]:
        return {key: self.encodings.data[key][index] for key in self.keys_to_use}
@@ -105,6 +117,27 @@ def __init__(
        self.tokenized_word_processor = tokenized_word_processor
        self.model.to(device)

        self._optimize_model()

    def _optimize_model(self) -> None:
        if _Feature.QUANTIZATION.is_enabled:
            from kazu.quantization.int8_x86_quantizer import _Int8X86Quantizer

            quantizer = _Int8X86Quantizer()
            self.model = quantizer.quantize(
                model=self.model,
                tokenizer=self.tokeniser,
                max_length=self.max_sequence_length,
            )

        if _Feature.INDUCTOR.is_enabled:
            from torch._inductor import config

            # Freezing folds weights into constants so TorchInductor can
            # apply constant-folding optimizations.
            config.freezing = True

            mode = "max-autotune" if _Feature.MAX_AUTOTUNE.is_enabled else None
            self.model = torch.compile(self.model, mode=mode)

@document_batch_step
def __call__(self, docs: list[Document]) -> None:
loader, id_section_map = self.get_dataloader(docs)