Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests_and_lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:

- name: Install dependencies
run: |
uv sync --dev --extra cpu
uv sync --dev --extra cpu --extra quantize

- name: Pytest
run: uv run pytest
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ cu130 = [
flash_attn = [
"flash-attn-cute",
]
quantize = [
"torchao>=0.16.0",
]

[tool.uv]
conflicts = [
Expand Down Expand Up @@ -97,4 +100,4 @@ project-includes = [
]
project-excludes = [
"tests/**",
]
]
12 changes: 10 additions & 2 deletions sarasa/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import enum
import sys
from pathlib import Path
from typing import Literal
Expand Down Expand Up @@ -95,16 +96,23 @@ def create(
"""


class Dtype(enum.StrEnum):
float8 = enum.auto()
bfloat16 = enum.auto()
float16 = enum.auto()
float32 = enum.auto()


@dataclasses.dataclass
class Train:
steps: int = 10_000

grad_clip: float | None = None

dtype: Literal["bfloat16", "float32"] = "float32"
dtype: Dtype = Dtype.float32
"""Dtype used for model initialization"""

amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
amp_dtype: Dtype = Dtype.bfloat16
"""Dtype used for automatic mixed precision training"""
Comment on lines +112 to 116
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Train.dtype now allows Dtype.float8, but elsewhere dtype is converted via getattr(torch, config.train.dtype) / torch.set_default_dtype(...), and torch doesn’t expose a torch.float8 dtype. Consider validating in Train.__post_init__ (or config parsing) that train.dtype cannot be float8, or introduce a dedicated mapping helper that handles float8 explicitly.

Copilot uses AI. Check for mistakes.
Comment on lines +99 to 116
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Config parsing/type-safety changed substantially by introducing the Dtype enum (including float8), but tests don’t currently cover CLI/config-file parsing for these new enum fields (valid values, invalid values, and the float8 special-case). Add tests to ensure Config.from_cli correctly parses --train.dtype/--train.amp_dtype into Dtype values and rejects unsupported combinations.

Copilot uses AI. Check for mistakes.

compile: bool = False
Expand Down
24 changes: 24 additions & 0 deletions sarasa/quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch

from sarasa.models import BaseModel


def to_float8(
    model: BaseModel,
) -> None:
    """Convert eligible ``torch.nn.Linear`` modules of *model* to float8 training layers.

    Modifies *model* in place using torchao's "tensorwise" float8 recipe.
    Requires the optional ``torchao`` dependency (the ``quantize`` extra).

    Args:
        model: The model whose linear layers should be converted.

    Raises:
        ImportError: If ``torchao`` is not installed.
    """
    # torchao is an optional dependency; fail with an actionable message
    # instead of a bare ModuleNotFoundError.
    try:
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    except ImportError as exc:
        raise ImportError(
            "torchao is required for float8 training. "
            "Install it with the 'quantize' extra (e.g. pip install 'sarasa[quantize]'), "
            "or choose a different amp_dtype."
        ) from exc

    # Optional: filter modules from being eligible for float8 conversion.
    def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
        # Don't convert the last module.
        # NOTE(review): matching fqn == "1" assumes an nn.Sequential-style layout;
        # confirm this actually excludes the model's output head (e.g. lm_head) —
        # it may not match named submodules like "lm_head" or "output".
        if fqn == "1":
            return False
        # Don't convert linear modules with weight dimensions not divisible by 16
        # (float8 kernels expect dims to be multiples of 16).
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    config = Float8LinearConfig.from_recipe_name("tensorwise")

    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
Comment on lines +2 to +24
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New float8 quantization behavior is introduced but there are no tests covering to_float8() (e.g., that it skips the output head and only converts eligible Linear layers). Add a focused unit test that constructs a small model and asserts which modules were converted.

Suggested change
from sarasa.models import BaseModel
def to_float8(
model: BaseModel,
) -> None:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
# don't convert the last module
if fqn == "1":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
import unittest
from sarasa.models import BaseModel
def _float8_module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
# don't convert the last module
if fqn == "1":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
def to_float8(
model: BaseModel,
) -> None:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(
model,
config=config,
module_filter_fn=_float8_module_filter_fn,
)
class TestFloat8ModuleFilter(unittest.TestCase):
def test_allows_eligible_linear_modules(self) -> None:
# Linear layer with dimensions divisible by 16 should be eligible
linear = torch.nn.Linear(16, 32)
self.assertTrue(_float8_module_filter_fn(linear, "0"))
def test_skips_last_module(self) -> None:
# Any module with fqn == "1" should be skipped
linear = torch.nn.Linear(16, 16)
self.assertFalse(_float8_module_filter_fn(linear, "1"))
def test_skips_non_divisible_linear_modules(self) -> None:
# Linear layers with dimensions not divisible by 16 should be skipped
linear_in_not_divisible = torch.nn.Linear(15, 16)
linear_out_not_divisible = torch.nn.Linear(16, 15)
self.assertFalse(_float8_module_filter_fn(linear_in_not_divisible, "0"))
self.assertFalse(_float8_module_filter_fn(linear_out_not_divisible, "0"))

Copilot uses AI. Check for mistakes.
14 changes: 12 additions & 2 deletions sarasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from sarasa.activation_checkpoint import apply_op_sac
from sarasa.checkpoint import Checkpointer
from sarasa.config import Config
from sarasa.config import Config, Dtype
from sarasa.evaluate import Evaluator
from sarasa.metrics import MetricsProcessor
from sarasa.utils import (
Expand Down Expand Up @@ -78,6 +78,12 @@ def __init__(
for i, block in enumerate(self.model.blocks):
self.model.blocks[i] = apply_op_sac(block)

if config.train.amp_dtype == Dtype.float8:
from sarasa.quantize import to_float8

logger.info("Converting model to float8")
to_float8(self.model)

if config.train.compile:
logger.info("Compiling the model")
for block in self.model.blocks:
Expand Down Expand Up @@ -107,7 +113,11 @@ def __init__(
logger.info(f"Gradient accumulation step is set to: {self.grad_accum_steps}")

self.amp_context: contextlib.AbstractContextManager = contextlib.nullcontext()
if world_size() == 1 or config.distributed.dp_shard_degree != -1:
if (
(config.train.dtype != config.train.amp_dtype)
and (config.train.amp_dtype != Dtype.float8)
and (world_size() == 1 or config.distributed.dp_shard_degree != -1)
):
self.amp_context = torch.autocast(
device_type=self.device.type,
dtype=getattr(torch, config.train.amp_dtype),
Expand Down