diff --git a/kazu/quantization/README.md b/kazu/quantization/README.md new file mode 100644 index 00000000..3bb39a48 --- /dev/null +++ b/kazu/quantization/README.md @@ -0,0 +1,77 @@ +# Accelerating Biomedical NER with Quantization + +## Introduction + +Quantization is an approach to represent model weights and/or activations using lower precision, aiming to reduce the computational costs of inference. The KAZU framework is designed for efficient and scalable document processing, without requiring a GPU. However, support for quantization on CPU is limited, as they generally lack native support for low precision data types (e.g. `bfloat16` or `int4`). + +In this project, we explore the use of quantization to accelerate CPU inference for biomedical named entity recognition. Specifically, we apply 8-bit quantization to the weights and activations (`W8A8`). This enables inference speedups on CPUs supporting the `VNNI` ([Vector Neural Network Instructions](https://en.wikichip.org/wiki/x86/avx512_vnni)) instruction set extension. + +## Supported hardware + +The following Linux command can be used to verify if the target CPU supports `VNNI`. This should output either `avx512_vnni` or `avx_vnni` on supported systems. + +```shell +lscpu | grep -o "\S*_vnni" +``` + +## Usage + +> [!IMPORTANT] +> Quantization is currently experimental as it relies on PyTorch prototype features. + +The following instructions apply to the [`TransformersModelForTokenClassificationNerStep`](https://astrazeneca.github.io/KAZU/_autosummary/kazu.steps.ner.hf_token_classification.html#kazu.steps.ner.hf_token_classification.TransformersModelForTokenClassificationNerStep). + +To enable quantization, set the following environment variables. TorchInductor is required to lower the quantized model to optimized instructions. + +```shell +export KAZU_ENABLE_INDUCTOR=1 +export KAZU_ENABLE_QUANTIZATION=1 +``` + +Optionally, TorchInductor [Max-Autotune](https://pytorch.org/tutorials/prototype/max_autotune_on_CPU_tutorial.html) can be enabled to automatically profile and select the best performing operation implementations. + +```shell +export KAZU_ENABLE_MAX_AUTOTUNE=1 +``` + +## Benchmarking + +To benchmark inference performance, we use [`evaluate_script.py`](/kazu/training/evaluate_script.py) with the ([`multilabel_biomedBERT`](/resources/kazu_model_pack_public/multilabel_biomedBERT)) model. We use the dataset from the following guide: [training multilabel NER](https://astrazeneca.github.io/KAZU/training_multilabel_ner.html). To simulate a long workload, we use the entire test set (365 documents), whereas for a short workload, we use the first 10 documents (alphabetically). + +The following benchmark results were collected on an Intel Xeon Gold 6252 CPU (single core) with PyTorch 2.6.0. + +### Short workload (10 documents) + +| Method | Mean F1 | Duration (S) | Speedup | +| :-------------------------------- | ------: | -----------: | ------: | +| Baseline | 0.9697 | 373.39 | 1.00 | +| Baseline (Inductor) | 0.9697 | 357.15 | 1.05 | +| Baseline (Inductor, Max-Autotune) | 0.9697 | 358.99 | 1.04 | +| W8A8 (Inductor) | 0.9656 | 194.84 | 1.92 | +| W8A8 (Inductor, Max-Autotune) | 0.9656 | 195.55 | 1.91 | + +### Long workload (365 documents) + +| Method | Mean F1 | Duration (S) | Speedup | +| :-------------------------------- | ------: | -----------: | ------: | +| Baseline | 0.9560 | 14797.50 | 1.00 | +| Baseline (Inductor) | 0.9560 | 13450.01 | 1.10 | +| Baseline (Inductor, Max-Autotune) | 0.9560 | 13469.89 | 1.10 | +| W8A8 (Inductor) | 0.9519 | 6642.87 | 2.23 | +| W8A8 (Inductor, Max-Autotune) | 0.9519 | 6801.01 | 2.18 | + +## Conclusion + +In our benchmarks, W8A8 quantization via TorchInductor achieved up to a 2× speedup over the baseline (W32A32) model. This incurs only a -0.4 point reduction in mean F1 score. For short workloads, the performance benefits of quantization are slightly reduced. Finally, we did not observe any additional performance benefits from using the TorchInductor Max-Autotune mode. + +## Future work + +- [ ] Load exported quantized models from checkpoints. +- [ ] Support mixed `int8` and `bfloat16` for speedups on newer CPUs. + +## Resources + +- [Tuning Guide for Deep Learning with Intel AVX-512 and Intel Deep Learning Boost on 3rd Generation Intel Xeon Scalable Processors](https://www.intel.com/content/www/us/en/developer/articles/guide/deep-learning-with-avx512-and-dl-boost.html) +- [PyTorch 2 Export Quantization with X86 Backend through Inductor](https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html) +- [(prototype) PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html) +- [Using Max-Autotune Compilation on CPU for Better Performance](https://pytorch.org/tutorials/prototype/max_autotune_on_CPU_tutorial.html) diff --git a/kazu/quantization/__init__.py b/kazu/quantization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kazu/quantization/int8_x86_quantizer.py b/kazu/quantization/int8_x86_quantizer.py new file mode 100644 index 00000000..0d01f0e6 --- /dev/null +++ b/kazu/quantization/int8_x86_quantizer.py @@ -0,0 +1,42 @@ +import torch +from torch.ao.quantization import move_exported_model_to_eval +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, + get_default_x86_inductor_quantization_config, +) +from torch.export import export_for_training +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.file_utils import PaddingStrategy + + +class _Int8X86Quantizer: + def __init__(self) -> None: + quantization_config = get_default_x86_inductor_quantization_config(is_dynamic=True) + + quantizer = X86InductorQuantizer() + quantizer.set_global(quantization_config) + self.quantizer = quantizer + + @torch.inference_mode() + def quantize( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + max_length: int, + ) -> torch.nn.Module: + example_inputs = tokenizer( + "", + max_length=max_length, + padding=PaddingStrategy.MAX_LENGTH, + return_tensors="pt", + ) + example_inputs = dict(example_inputs.to(model.device)) + + exported_model = export_for_training(model, args=(), kwargs=example_inputs).module() + + exported_model = prepare_pt2e(exported_model, self.quantizer) # type: ignore[arg-type] + exported_model(**example_inputs) + + exported_model = convert_pt2e(exported_model) + return move_exported_model_to_eval(exported_model) # type: ignore[no-any-return] diff --git a/kazu/steps/ner/hf_token_classification.py b/kazu/steps/ner/hf_token_classification.py index 82e784b4..8eadc34b 100644 --- a/kazu/steps/ner/hf_token_classification.py +++ b/kazu/steps/ner/hf_token_classification.py @@ -1,5 +1,7 @@ import logging +import os from collections.abc import Iterator +from enum import Enum from typing import Any, Iterable, Optional, cast import torch @@ -27,6 +29,16 @@ logger = logging.getLogger(__name__) +class _Feature(str, Enum): + INDUCTOR = "KAZU_ENABLE_INDUCTOR" + MAX_AUTOTUNE = "KAZU_ENABLE_MAX_AUTOTUNE" + QUANTIZATION = "KAZU_ENABLE_QUANTIZATION" + + @property + def is_enabled(self) -> bool: + return os.getenv(self.value, "0") == "1" + + class HFDataset(IterableDataset[dict[str, Any]]): def __getitem__(self, index: int) -> dict[str, Any]: return {key: self.encodings.data[key][index] for key in self.keys_to_use} @@ -105,6 +117,27 @@ def __init__( self.tokenized_word_processor = tokenized_word_processor self.model.to(device) + self._optimize_model() + + def _optimize_model(self) -> None: + if _Feature.QUANTIZATION.is_enabled: + from kazu.quantization.int8_x86_quantizer import _Int8X86Quantizer + + quantizer = _Int8X86Quantizer() + self.model = quantizer.quantize( + model=self.model, + tokenizer=self.tokeniser, + max_length=self.max_sequence_length, + ) + + if _Feature.INDUCTOR.is_enabled: + from torch._inductor import config + + config.freezing = True + + mode = "max-autotune" if _Feature.MAX_AUTOTUNE.is_enabled else None + self.model = torch.compile(self.model, mode=mode) + @document_batch_step def __call__(self, docs: list[Document]) -> None: loader, id_section_map = self.get_dataloader(docs)