Support float8 training#13

Open
moskomule wants to merge 4 commits into main from feat/float8

Conversation

@moskomule (Owner)

This pull request introduces support for float8 quantization in model training, refactors dtype handling to use a new enum for safer and more extensible configuration, and integrates these changes throughout the codebase. The most important changes are grouped below:

Quantization Support

  • Added a new sarasa/quantize.py module with a to_float8 function that converts models for float8 training using torchao. This enables efficient float8 quantization and filtering of eligible modules.
  • Updated pyproject.toml to include torchao as a dependency for quantization features.

Dtype Refactoring

  • Introduced a Dtype enum in sarasa/config.py to replace string literals for dtype configuration, improving type safety and extensibility. All relevant configuration fields (dtype, amp_dtype) now use this enum. [1] [2]
  • Updated dtype imports and usage throughout the codebase, including sarasa/train.py and sarasa/utils.py, to reference the new Dtype enum. [1] [2] [3]

Training Integration

  • Modified sarasa/train.py to convert the model to float8 when amp_dtype is set to Dtype.float8, and adjusted autocast logic to account for the new dtype enum and float8 handling. [1] [2]

Copilot AI review requested due to automatic review settings February 17, 2026 12:53
Copilot AI left a comment

Pull request overview

Adds float8 quantization support for training and refactors dtype configuration to a Dtype enum, wiring the new dtype handling through training and distributed setup.

Changes:

  • Introduces sarasa/quantize.py with a to_float8() helper backed by torchao.
  • Replaces string/Literal dtype configuration with a Dtype StrEnum in sarasa/config.py.
  • Updates training and distributed mixed-precision logic to use Dtype and triggers float8 conversion when configured.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

Reviewed files:

  • sarasa/utils.py: Updates FSDP dtype casting/usage to Dtype and uses it to build the FSDP mixed-precision policy.
  • sarasa/train.py: Imports Dtype, converts the model to float8 when configured, and adjusts autocast gating logic.
  • sarasa/quantize.py: New module implementing float8 training conversion via torchao.
  • sarasa/config.py: Adds the Dtype enum and migrates Train/FSDP dtype fields to use it.
  • pyproject.toml: Adds torchao as an optional dependency under a new quantize extra.
Comments suppressed due to low confidence (1)

sarasa/utils.py:200

  • MixedPrecisionPolicy(param_dtype=getattr(torch, config.amp_dtype), ...) will crash if amp_dtype is set to Dtype.float8 because torch doesn’t have a torch.float8 attribute. Add an explicit mapping from Dtype -> torch.dtype (and either disallow float8 here or map to the specific torch float8 dtype(s) you intend), and/or validate that FSDP is not used with amp_dtype=float8.
        mp_policy = MixedPrecisionPolicy(
            param_dtype=getattr(torch, config.amp_dtype),
            reduce_dtype=getattr(torch, config.dtype),
        )


sarasa/train.py Outdated
Comment on lines +115 to +117:

    (config.train.dtype != config.train.amp_dtype)
    and (config.train.amp_dtype != Dtype.float8)
    and (world_size() == 1 or config.distributed.name == "fsdp")
Copilot AI Feb 17, 2026

The new AMP enabling condition disables autocast for multi-process DDP runs (world_size()>1 and distributed.name=="ddp"), which means train.amp_dtype will be ignored in the common distributed case. If this isn’t intentional, adjust the condition so autocast is enabled for DDP as well (and only excluded where it’s actually incompatible, such as float8).

Suggested change:

    - (config.train.dtype != config.train.amp_dtype)
    - and (config.train.amp_dtype != Dtype.float8)
    - and (world_size() == 1 or config.distributed.name == "fsdp")
    + config.train.dtype != config.train.amp_dtype
    + and config.train.amp_dtype != Dtype.float8

Comment on lines +112 to 116:

        dtype: Dtype = Dtype.float32
        """Dtype used for model initialization"""

    -   amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
    +   amp_dtype: Dtype = Dtype.bfloat16
        """Dtype used for automatic mixed precision training"""
Copilot AI Feb 17, 2026

Train.dtype now allows Dtype.float8, but elsewhere dtype is converted via getattr(torch, config.train.dtype) / torch.set_default_dtype(...), and torch doesn’t expose a torch.float8 dtype. Consider validating in Train.__post_init__ (or config parsing) that train.dtype cannot be float8, or introduce a dedicated mapping helper that handles float8 explicitly.

Comment on lines +11 to +20
# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
# don't convert the last module
if fqn == "1":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
Copilot AI Feb 17, 2026

module_filter_fn claims to “not convert the last module”, but checking fqn == "1" won’t match this codebase’s output heads (e.g., lm_head in nanochat_gpt or output in llama3). As written, the final projection layer will still be eligible for float8 conversion. Update the filter to exclude the actual head module(s) by FQN (or by a more robust predicate like matching out_features == vocab_size).
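
A sketch of the more robust predicate the reviewer suggests. VOCAB_SIZE is a hypothetical stand-in; in practice it would come from the model config:

```python
import torch

# hypothetical stand-in for the model's vocabulary size
VOCAB_SIZE = 50_304


def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    if isinstance(mod, torch.nn.Linear):
        # skip the output head: the projection onto the vocabulary
        if mod.out_features == VOCAB_SIZE:
            return False
        # skip linear layers whose dims aren't divisible by 16
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True
```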

    def to_float8(
        model: BaseModel,
    ) -> None:
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
Copilot AI Feb 17, 2026

torchao is an optional dependency (added via the quantize extra), but importing it will currently raise a bare ModuleNotFoundError if the extra isn’t installed. Catch ImportError here and raise a clearer error message instructing users to install with the appropriate extra (or otherwise disable float8).

Suggested change:

    -   from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    +   try:
    +       from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    +   except ImportError as exc:
    +       raise ImportError(
    +           "torchao is required to use float8 quantization. "
    +           "Please install sarasa with the 'quantize' extra, for example:\n"
    +           "    pip install 'sarasa[quantize]'\n"
    +           "or disable float8 quantization in your configuration."
    +       ) from exc

Comment on lines +2 to +24:

    from sarasa.models import BaseModel


    def to_float8(
        model: BaseModel,
    ) -> None:
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training

        # optional: filter modules from being eligible for float8 conversion
        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # don't convert the last module
            if fqn == "1":
                return False
            # don't convert linear modules with weight dimensions not divisible by 16
            if isinstance(mod, torch.nn.Linear):
                if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                    return False
            return True

        config = Float8LinearConfig.from_recipe_name("tensorwise")

        convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
Copilot AI Feb 17, 2026

New float8 quantization behavior is introduced but there are no tests covering to_float8() (e.g., that it skips the output head and only converts eligible Linear layers). Add a focused unit test that constructs a small model and asserts which modules were converted.

Suggested change:

    import unittest

    from sarasa.models import BaseModel


    def _float8_module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
        # don't convert the last module
        if fqn == "1":
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True


    def to_float8(
        model: BaseModel,
    ) -> None:
        from torchao.float8 import Float8LinearConfig, convert_to_float8_training

        config = Float8LinearConfig.from_recipe_name("tensorwise")
        convert_to_float8_training(
            model,
            config=config,
            module_filter_fn=_float8_module_filter_fn,
        )


    class TestFloat8ModuleFilter(unittest.TestCase):
        def test_allows_eligible_linear_modules(self) -> None:
            # Linear layer with dimensions divisible by 16 should be eligible
            linear = torch.nn.Linear(16, 32)
            self.assertTrue(_float8_module_filter_fn(linear, "0"))

        def test_skips_last_module(self) -> None:
            # Any module with fqn == "1" should be skipped
            linear = torch.nn.Linear(16, 16)
            self.assertFalse(_float8_module_filter_fn(linear, "1"))

        def test_skips_non_divisible_linear_modules(self) -> None:
            # Linear layers with dimensions not divisible by 16 should be skipped
            linear_in_not_divisible = torch.nn.Linear(15, 16)
            linear_out_not_divisible = torch.nn.Linear(16, 15)
            self.assertFalse(_float8_module_filter_fn(linear_in_not_divisible, "0"))
            self.assertFalse(_float8_module_filter_fn(linear_out_not_divisible, "0"))

Comment on lines +99 to 116
class Dtype(enum.StrEnum):
float8 = enum.auto()
bfloat16 = enum.auto()
float16 = enum.auto()
float32 = enum.auto()


@dataclasses.dataclass
class Train:
steps: int = 10_000

grad_clip: float | None = None

dtype: Literal["bfloat16", "float32"] = "float32"
dtype: Dtype = Dtype.float32
"""Dtype used for model initialization"""

amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
amp_dtype: Dtype = Dtype.bfloat16
"""Dtype used for automatic mixed precision training"""
Copilot AI Feb 17, 2026

Config parsing/type-safety changed substantially by introducing the Dtype enum (including float8), but tests don’t currently cover CLI/config-file parsing for these new enum fields (valid values, invalid values, and the float8 special-case). Add tests to ensure Config.from_cli correctly parses --train.dtype/--train.amp_dtype into Dtype values and rejects unsupported combinations.

sarasa/utils.py Outdated
Comment on lines 191 to 196:

        from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

        config = typing.cast(FSDPConfig, config)
    -   config.amp_dtype = typing.cast(str, config.amp_dtype)
    -   config.dtype = typing.cast(str, config.dtype)
    +   config.dtype = typing.cast(Dtype, config.dtype)
    +   config.amp_dtype = typing.cast(Dtype, config.amp_dtype)
Copilot AI Feb 17, 2026

apply_distributed() uses typing.cast(FSDPConfig, ...) and typing.cast(Dtype, ...), but FSDPConfig/Dtype are only imported under if typing.TYPE_CHECKING:. At runtime this will raise NameError the first time the FSDP branch is executed. Import these symbols at runtime (top-level or inside the FSDP branch), or change the casts to use string annotations (e.g., typing.cast("FSDPConfig", ...)) so they don’t require runtime names.
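
A sketch of the string-annotation variant of the fix. The surrounding function is heavily simplified and apply_distributed's real signature is not shown in this diff:

```python
import typing

if typing.TYPE_CHECKING:
    # only needed by type checkers; never imported at runtime
    from sarasa.config import Dtype, FSDPConfig


def apply_fsdp_casts(config):
    # string forms keep the casts purely annotational: typing.cast returns
    # its value unchanged at runtime, so Dtype/FSDPConfig need not exist
    # as runtime names and no NameError can occur
    config = typing.cast("FSDPConfig", config)
    config.dtype = typing.cast("Dtype", config.dtype)
    config.amp_dtype = typing.cast("Dtype", config.amp_dtype)
    return config
```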
