-
Notifications
You must be signed in to change notification settings - Fork 0
Support float8 training #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import dataclasses | ||
| import enum | ||
| import sys | ||
| from pathlib import Path | ||
| from typing import Literal | ||
|
|
@@ -95,16 +96,23 @@ def create( | |
| """ | ||
|
|
||
|
|
||
| class Dtype(enum.StrEnum): | ||
| float8 = enum.auto() | ||
| bfloat16 = enum.auto() | ||
| float16 = enum.auto() | ||
| float32 = enum.auto() | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Train: | ||
| steps: int = 10_000 | ||
|
|
||
| grad_clip: float | None = None | ||
|
|
||
| dtype: Literal["bfloat16", "float32"] = "float32" | ||
| dtype: Dtype = Dtype.float32 | ||
| """Dtype used for model initialization""" | ||
|
|
||
| amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16" | ||
| amp_dtype: Dtype = Dtype.bfloat16 | ||
| """Dtype used for automatic mixed precision training""" | ||
|
Comment on lines
+99
to
116
|
||
|
|
||
| compile: bool = False | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,24 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sarasa.models import BaseModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def to_float8( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model: BaseModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.float8 import Float8LinearConfig, convert_to_float8_training | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.float8 import Float8LinearConfig, convert_to_float8_training | |
| try: | |
| from torchao.float8 import Float8LinearConfig, convert_to_float8_training | |
| except ImportError as exc: | |
| raise ImportError( | |
| "torchao is required to use float8 quantization. " | |
| "Please install sarasa with the 'quantize' extra, for example:\n" | |
| " pip install 'sarasa[quantize]'\n" | |
| "or disable float8 quantization in your configuration." | |
| ) from exc |
Copilot
AI
Feb 17, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
module_filter_fn claims to “not convert the last module”, but checking fqn == "1" won’t match this codebase’s output heads (e.g., lm_head in nanochat_gpt or output in llama3). As written, the final projection layer will still be eligible for float8 conversion. Update the filter to exclude the actual head module(s) by FQN (or by a more robust predicate like matching out_features == vocab_size).
Copilot
AI
Feb 17, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New float8 quantization behavior is introduced but there are no tests covering to_float8() (e.g., that it skips the output head and only converts eligible Linear layers). Add a focused unit test that constructs a small model and asserts which modules were converted.
| from sarasa.models import BaseModel | |
| def to_float8( | |
| model: BaseModel, | |
| ) -> None: | |
| from torchao.float8 import Float8LinearConfig, convert_to_float8_training | |
| # optional: filter modules from being eligible for float8 conversion | |
| def module_filter_fn(mod: torch.nn.Module, fqn: str): | |
| # don't convert the last module | |
| if fqn == "1": | |
| return False | |
| # don't convert linear modules with weight dimensions not divisible by 16 | |
| if isinstance(mod, torch.nn.Linear): | |
| if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: | |
| return False | |
| return True | |
| config = Float8LinearConfig.from_recipe_name("tensorwise") | |
| convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn) | |
| import unittest | |
| from sarasa.models import BaseModel | |
| def _float8_module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool: | |
| # don't convert the last module | |
| if fqn == "1": | |
| return False | |
| # don't convert linear modules with weight dimensions not divisible by 16 | |
| if isinstance(mod, torch.nn.Linear): | |
| if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: | |
| return False | |
| return True | |
| def to_float8( | |
| model: BaseModel, | |
| ) -> None: | |
| from torchao.float8 import Float8LinearConfig, convert_to_float8_training | |
| config = Float8LinearConfig.from_recipe_name("tensorwise") | |
| convert_to_float8_training( | |
| model, | |
| config=config, | |
| module_filter_fn=_float8_module_filter_fn, | |
| ) | |
| class TestFloat8ModuleFilter(unittest.TestCase): | |
| def test_allows_eligible_linear_modules(self) -> None: | |
| # Linear layer with dimensions divisible by 16 should be eligible | |
| linear = torch.nn.Linear(16, 32) | |
| self.assertTrue(_float8_module_filter_fn(linear, "0")) | |
| def test_skips_last_module(self) -> None: | |
| # Any module with fqn == "1" should be skipped | |
| linear = torch.nn.Linear(16, 16) | |
| self.assertFalse(_float8_module_filter_fn(linear, "1")) | |
| def test_skips_non_divisible_linear_modules(self) -> None: | |
| # Linear layers with dimensions not divisible by 16 should be skipped | |
| linear_in_not_divisible = torch.nn.Linear(15, 16) | |
| linear_out_not_divisible = torch.nn.Linear(16, 15) | |
| self.assertFalse(_float8_module_filter_fn(linear_in_not_divisible, "0")) | |
| self.assertFalse(_float8_module_filter_fn(linear_out_not_divisible, "0")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`Train.dtype` now allows `Dtype.float8`, but elsewhere `dtype` is converted via `getattr(torch, config.train.dtype)` / `torch.set_default_dtype(...)`, and `torch` doesn't expose a `torch.float8` dtype. Consider validating in `Train.__post_init__` (or during config parsing) that `train.dtype` cannot be `float8`, or introduce a dedicated mapping helper that handles float8 explicitly.