diff --git a/.github/workflows/tests_and_lint.yaml b/.github/workflows/tests_and_lint.yaml
index a34ba65..5a9b0e7 100644
--- a/.github/workflows/tests_and_lint.yaml
+++ b/.github/workflows/tests_and_lint.yaml
@@ -28,7 +28,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          uv sync --dev --extra cpu
+          uv sync --dev --extra cpu --extra quantize
 
       - name: Pytest
         run: uv run pytest
diff --git a/pyproject.toml b/pyproject.toml
index d8b43f4..e343ced 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,9 @@ cu130 = [
 flash_attn = [
     "flash-attn-cute",
 ]
+quantize = [
+    "torchao>=0.16.0",
+]
 
 [tool.uv]
 conflicts = [
@@ -97,4 +100,4 @@ project-includes = [
 ]
 project-excludes = [
     "tests/**",
-]
\ No newline at end of file
+]
diff --git a/sarasa/config.py b/sarasa/config.py
index 2e5b2e0..c8c1cbb 100644
--- a/sarasa/config.py
+++ b/sarasa/config.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
+import enum
 import sys
 from pathlib import Path
 from typing import Literal
@@ -95,16 +96,23 @@ def create(
     """
 
 
+class Dtype(enum.StrEnum):
+    float8 = enum.auto()
+    bfloat16 = enum.auto()
+    float16 = enum.auto()
+    float32 = enum.auto()
+
+
 @dataclasses.dataclass
 class Train:
     steps: int = 10_000
 
     grad_clip: float | None = None
 
-    dtype: Literal["bfloat16", "float32"] = "float32"
+    dtype: Dtype = Dtype.float32
     """Dtype used for model initialization"""
 
-    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    amp_dtype: Dtype = Dtype.bfloat16
     """Dtype used for automatic mixed precision training"""
 
     compile: bool = False
diff --git a/sarasa/quantize.py b/sarasa/quantize.py
new file mode 100644
index 0000000..ff9bfda
--- /dev/null
+++ b/sarasa/quantize.py
@@ -0,0 +1,24 @@
+import torch
+
+from sarasa.models import BaseModel
+
+
+def to_float8(
+    model: BaseModel,
+) -> None:
+    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+    # optional: filter modules from being eligible for float8 conversion
+    def module_filter_fn(mod: torch.nn.Module, fqn: str):
+        # don't convert the last module
+        if fqn == "1":
+            return False
+        # don't convert linear modules with weight dimensions not divisible by 16
+        if isinstance(mod, torch.nn.Linear):
+            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+                return False
+        return True
+
+    config = Float8LinearConfig.from_recipe_name("tensorwise")
+
+    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
diff --git a/sarasa/train.py b/sarasa/train.py
index c9387a7..e9628d5 100644
--- a/sarasa/train.py
+++ b/sarasa/train.py
@@ -11,7 +11,7 @@
 
 from sarasa.activation_checkpoint import apply_op_sac
 from sarasa.checkpoint import Checkpointer
-from sarasa.config import Config
+from sarasa.config import Config, Dtype
 from sarasa.evaluate import Evaluator
 from sarasa.metrics import MetricsProcessor
 from sarasa.utils import (
@@ -78,6 +78,12 @@ def __init__(
         for i, block in enumerate(self.model.blocks):
             self.model.blocks[i] = apply_op_sac(block)
 
+        if config.train.amp_dtype == Dtype.float8:
+            from sarasa.quantize import to_float8
+
+            logger.info("Converting model to float8")
+            to_float8(self.model)
+
         if config.train.compile:
             logger.info("Compiling the model")
             for block in self.model.blocks:
@@ -107,7 +113,11 @@ def __init__(
         logger.info(f"Gradient accumulation step is set to: {self.grad_accum_steps}")
 
         self.amp_context: contextlib.AbstractContextManager = contextlib.nullcontext()
-        if world_size() == 1 or config.distributed.dp_shard_degree != -1:
+        if (
+            (config.train.dtype != config.train.amp_dtype)
+            and (config.train.amp_dtype != Dtype.float8)
+            and (world_size() == 1 or config.distributed.dp_shard_degree != -1)
+        ):
             self.amp_context = torch.autocast(
                 device_type=self.device.type,
                 dtype=getattr(torch, config.train.amp_dtype),