From bfbf90d606e875c802f26513e4b0bfe577a9ae90 Mon Sep 17 00:00:00 2001 From: Ryuichiro Hataya Date: Mon, 16 Feb 2026 22:46:24 +0000 Subject: [PATCH 1/2] add --- pyproject.toml | 5 ++++- sarasa/config.py | 16 ++++++++++++---- sarasa/quantize.py | 24 ++++++++++++++++++++++++ sarasa/train.py | 14 ++++++++++++-- sarasa/utils.py | 6 +++--- 5 files changed, 55 insertions(+), 10 deletions(-) create mode 100644 sarasa/quantize.py diff --git a/pyproject.toml b/pyproject.toml index e97885f..256a0d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,9 @@ cu130 = [ flash_attn = [ "flash-attn-cute", ] +quantize = [ + "torchao>=0.16.0", +] [tool.uv] conflicts = [ @@ -86,4 +89,4 @@ project-includes = [ ] project-excludes = [ "tests/**", -] \ No newline at end of file +] diff --git a/sarasa/config.py b/sarasa/config.py index f4f9db8..774b8b8 100644 --- a/sarasa/config.py +++ b/sarasa/config.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import enum import sys from pathlib import Path from typing import Literal @@ -95,16 +96,23 @@ def create( """ +class Dtype(enum.StrEnum): + float8 = enum.auto() + bfloat16 = enum.auto() + float16 = enum.auto() + float32 = enum.auto() + + @dataclasses.dataclass class Train: steps: int = 10_000 grad_clip: float | None = None - dtype: Literal["bfloat16", "float32"] = "float32" + dtype: Dtype = Dtype.float32 """Dtype used for model initialization""" - amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16" + amp_dtype: Dtype = Dtype.bfloat16 """Dtype used for automatic mixed precision training""" compile: bool = False @@ -193,10 +201,10 @@ class FSDP(Distributed): reshard_after_forward: bool = False """Whether to reshard model parameters after each forward pass (FSDP only).""" - dtype: str | None = None + dtype: Dtype | None = None """Dtype for FSDP reduce operations. If None, uses train.dtype.""" - amp_dtype: str | None = None + amp_dtype: Dtype | None = None """Dtype for FSDP parameter storage. If None, uses train.amp_dtype.""" diff --git a/sarasa/quantize.py b/sarasa/quantize.py new file mode 100644 index 0000000..ff9bfda --- /dev/null +++ b/sarasa/quantize.py @@ -0,0 +1,24 @@ +import torch + +from sarasa.models import BaseModel + + +def to_float8( + model: BaseModel, +) -> None: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + # optional: filter modules from being eligible for float8 conversion + def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the last module + if fqn == "1": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + config = Float8LinearConfig.from_recipe_name("tensorwise") + + convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn) diff --git a/sarasa/train.py b/sarasa/train.py index df25b69..9f2ade4 100644 --- a/sarasa/train.py +++ b/sarasa/train.py @@ -11,7 +11,7 @@ from sarasa.activation_checkpoint import apply_op_sac from sarasa.checkpoint import Checkpointer -from sarasa.config import Config +from sarasa.config import Config, Dtype from sarasa.evaluate import Evaluator from sarasa.metrics import MetricsProcessor from sarasa.utils import ( @@ -78,6 +78,12 @@ def __init__( for i, block in enumerate(self.model.blocks): self.model.blocks[i] = apply_op_sac(block) + if config.train.amp_dtype == Dtype.float8: + from sarasa.quantize import to_float8 + + logger.info("Converting model to float8") + to_float8(self.model) + if config.train.compile: logger.info("Compiling the model") for block in self.model.blocks: @@ -105,7 +111,11 @@ def __init__( logger.info(f"Gradient accumulation step is set to: {self.grad_accum_steps}") self.amp_context: contextlib.AbstractContextManager = contextlib.nullcontext() - if config.distributed.name != "fsdp": + if ( + (config.train.dtype != config.train.amp_dtype) + and (config.train.amp_dtype != Dtype.float8) + and (world_size() == 1 or config.distributed.name == "fsdp") + ): self.amp_context = torch.autocast( device_type=self.device.type, dtype=getattr(torch, config.train.amp_dtype), diff --git a/sarasa/utils.py b/sarasa/utils.py index e449597..b4aef30 100644 --- a/sarasa/utils.py +++ b/sarasa/utils.py @@ -16,7 +16,7 @@ if typing.TYPE_CHECKING: from sarasa.config import FSDP as FSDPConfig - from sarasa.config import Config, Distributed, Profile + from sarasa.config import Config, Distributed, Dtype, Profile from sarasa.models import BaseModel @@ -191,8 +191,8 @@ def apply_distributed( from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard config = typing.cast(FSDPConfig, config) - config.amp_dtype = typing.cast(str, config.amp_dtype) - config.dtype = typing.cast(str, config.dtype) + config.dtype = typing.cast(Dtype, config.dtype) + config.amp_dtype = typing.cast(Dtype, config.amp_dtype) mp_policy = MixedPrecisionPolicy( param_dtype=getattr(torch, config.amp_dtype), From 990b2ac6ae22659fbfea75ab2d41c4ecf3ad4f15 Mon Sep 17 00:00:00 2001 From: Ryuichiro Hataya Date: Tue, 17 Feb 2026 12:56:29 +0000 Subject: [PATCH 2/2] fix --- .github/workflows/tests_and_lint.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests_and_lint.yaml b/.github/workflows/tests_and_lint.yaml index a34ba65..5a9b0e7 100644 --- a/.github/workflows/tests_and_lint.yaml +++ b/.github/workflows/tests_and_lint.yaml @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | - uv sync --dev --extra cpu + uv sync --dev --extra cpu --extra quantize - name: Pytest run: uv run pytest