Conversation
Pull request overview
Adds float8 quantization support for training and refactors dtype configuration to a Dtype enum, wiring the new dtype handling through training and distributed setup.
Changes:
- Introduces `sarasa/quantize.py` with a `to_float8()` helper backed by `torchao`.
- Replaces string/Literal dtype configuration with a `Dtype` `StrEnum` in `sarasa/config.py`.
- Updates training and distributed mixed-precision logic to use `Dtype` and triggers float8 conversion when configured.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `sarasa/utils.py` | Updates FSDP dtype casting/usage to `Dtype` and uses it to build the FSDP mixed-precision policy. |
| `sarasa/train.py` | Imports `Dtype`, converts the model to float8 when configured, and adjusts autocast gating logic. |
| `sarasa/quantize.py` | New module implementing float8 training conversion via `torchao`. |
| `sarasa/config.py` | Adds `Dtype` enum and migrates Train/FSDP dtype fields to use it. |
| `pyproject.toml` | Adds `torchao` as an optional dependency under a new `quantize` extra. |
Comments suppressed due to low confidence (1)
sarasa/utils.py:200
`MixedPrecisionPolicy(param_dtype=getattr(torch, config.amp_dtype), ...)` will crash if `amp_dtype` is set to `Dtype.float8` because `torch` doesn't have a `torch.float8` attribute. Add an explicit mapping from `Dtype` -> `torch.dtype` (and either disallow float8 here or map to the specific torch float8 dtype(s) you intend), and/or validate that FSDP is not used with `amp_dtype=float8`.
```python
mp_policy = MixedPrecisionPolicy(
    param_dtype=getattr(torch, config.amp_dtype),
    reduce_dtype=getattr(torch, config.dtype),
)
```
sarasa/train.py
Outdated
```python
(config.train.dtype != config.train.amp_dtype)
and (config.train.amp_dtype != Dtype.float8)
and (world_size() == 1 or config.distributed.name == "fsdp")
```
The new AMP enabling condition disables autocast for multi-process DDP runs (`world_size() > 1` and `distributed.name == "ddp"`), which means `train.amp_dtype` will be ignored in the common distributed case. If this isn't intentional, adjust the condition so autocast is enabled for DDP as well (and only excluded where it's actually incompatible, such as float8).
Suggested change:

```diff
-(config.train.dtype != config.train.amp_dtype)
-and (config.train.amp_dtype != Dtype.float8)
-and (world_size() == 1 or config.distributed.name == "fsdp")
+config.train.dtype != config.train.amp_dtype
+and config.train.amp_dtype != Dtype.float8
```
```diff
 dtype: Dtype = Dtype.float32
 """Dtype used for model initialization"""

-amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+amp_dtype: Dtype = Dtype.bfloat16
 """Dtype used for automatic mixed precision training"""
```
`Train.dtype` now allows `Dtype.float8`, but elsewhere the dtype is converted via `getattr(torch, config.train.dtype)` / `torch.set_default_dtype(...)`, and torch doesn't expose a `torch.float8` dtype. Consider validating in `Train.__post_init__` (or config parsing) that `train.dtype` cannot be float8, or introduce a dedicated mapping helper that handles float8 explicitly.
```python
# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True
```
`module_filter_fn` claims to "not convert the last module", but checking `fqn == "1"` won't match this codebase's output heads (e.g., `lm_head` in nanochat_gpt or `output` in llama3). As written, the final projection layer will still be eligible for float8 conversion. Update the filter to exclude the actual head module(s) by FQN (or by a more robust predicate like matching `out_features == vocab_size`).
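One possible shape for the more robust predicate, as a sketch; the factory name, head FQNs, and `vocab_size` plumbing are assumptions to adapt to the actual models:

```python
import torch


def make_module_filter_fn(head_fqns: set[str], vocab_size: int):
    """Build a torchao-style module_filter_fn that skips the output head.

    head_fqns and vocab_size are illustrative: adapt them to the real
    models (e.g. "lm_head" in nanochat_gpt, "output" in llama3).
    """

    def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
        # skip known output heads by fully-qualified name
        if fqn in head_fqns:
            return False
        if isinstance(mod, torch.nn.Linear):
            # belt-and-braces: a vocab-sized projection is a head too
            if mod.out_features == vocab_size:
                return False
            # float8 kernels need dims divisible by 16
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    return module_filter_fn
```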
```python
def to_float8(
    model: BaseModel,
) -> None:
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
```
`torchao` is an optional dependency (added via the `quantize` extra), but importing it will currently raise a bare `ModuleNotFoundError` if the extra isn't installed. Catch `ImportError` here and raise a clearer error message instructing users to install with the appropriate extra (or otherwise disable float8).
Suggested change:

```diff
-from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+try:
+    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+except ImportError as exc:
+    raise ImportError(
+        "torchao is required to use float8 quantization. "
+        "Please install sarasa with the 'quantize' extra, for example:\n"
+        "    pip install 'sarasa[quantize]'\n"
+        "or disable float8 quantization in your configuration."
+    ) from exc
```
```python
from sarasa.models import BaseModel


def to_float8(
    model: BaseModel,
) -> None:
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the last module
        if fqn == "1":
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    config = Float8LinearConfig.from_recipe_name("tensorwise")

    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
```
New float8 quantization behavior is introduced but there are no tests covering `to_float8()` (e.g., that it skips the output head and only converts eligible Linear layers). Add a focused unit test that constructs a small model and asserts which modules were converted.
Suggested change:

```diff
-from sarasa.models import BaseModel
-
-
-def to_float8(
-    model: BaseModel,
-) -> None:
-    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
-
-    # optional: filter modules from being eligible for float8 conversion
-    def module_filter_fn(mod: torch.nn.Module, fqn: str):
-        # don't convert the last module
-        if fqn == "1":
-            return False
-        # don't convert linear modules with weight dimensions not divisible by 16
-        if isinstance(mod, torch.nn.Linear):
-            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-                return False
-        return True
-
-    config = Float8LinearConfig.from_recipe_name("tensorwise")
-
-    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
+import unittest
+
+from sarasa.models import BaseModel
+
+
+def _float8_module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
+    # don't convert the last module
+    if fqn == "1":
+        return False
+    # don't convert linear modules with weight dimensions not divisible by 16
+    if isinstance(mod, torch.nn.Linear):
+        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+            return False
+    return True
+
+
+def to_float8(
+    model: BaseModel,
+) -> None:
+    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+    config = Float8LinearConfig.from_recipe_name("tensorwise")
+    convert_to_float8_training(
+        model,
+        config=config,
+        module_filter_fn=_float8_module_filter_fn,
+    )
+
+
+class TestFloat8ModuleFilter(unittest.TestCase):
+    def test_allows_eligible_linear_modules(self) -> None:
+        # Linear layer with dimensions divisible by 16 should be eligible
+        linear = torch.nn.Linear(16, 32)
+        self.assertTrue(_float8_module_filter_fn(linear, "0"))
+
+    def test_skips_last_module(self) -> None:
+        # Any module with fqn == "1" should be skipped
+        linear = torch.nn.Linear(16, 16)
+        self.assertFalse(_float8_module_filter_fn(linear, "1"))
+
+    def test_skips_non_divisible_linear_modules(self) -> None:
+        # Linear layers with dimensions not divisible by 16 should be skipped
+        linear_in_not_divisible = torch.nn.Linear(15, 16)
+        linear_out_not_divisible = torch.nn.Linear(16, 15)
+        self.assertFalse(_float8_module_filter_fn(linear_in_not_divisible, "0"))
+        self.assertFalse(_float8_module_filter_fn(linear_out_not_divisible, "0"))
```
```diff
+class Dtype(enum.StrEnum):
+    float8 = enum.auto()
+    bfloat16 = enum.auto()
+    float16 = enum.auto()
+    float32 = enum.auto()
+
+
 @dataclasses.dataclass
 class Train:
     steps: int = 10_000

     grad_clip: float | None = None

-    dtype: Literal["bfloat16", "float32"] = "float32"
+    dtype: Dtype = Dtype.float32
     """Dtype used for model initialization"""

-    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    amp_dtype: Dtype = Dtype.bfloat16
     """Dtype used for automatic mixed precision training"""
```
Config parsing/type-safety changed substantially by introducing the `Dtype` enum (including float8), but tests don't currently cover CLI/config-file parsing for these new enum fields (valid values, invalid values, and the float8 special case). Add tests to ensure `Config.from_cli` correctly parses `--train.dtype`/`--train.amp_dtype` into `Dtype` values and rejects unsupported combinations.
sarasa/utils.py
Outdated
```diff
 from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

 config = typing.cast(FSDPConfig, config)
-config.amp_dtype = typing.cast(str, config.amp_dtype)
-config.dtype = typing.cast(str, config.dtype)
+config.dtype = typing.cast(Dtype, config.dtype)
+config.amp_dtype = typing.cast(Dtype, config.amp_dtype)
```
`apply_distributed()` uses `typing.cast(FSDPConfig, ...)` and `typing.cast(Dtype, ...)`, but `FSDPConfig`/`Dtype` are only imported under `if typing.TYPE_CHECKING:`. At runtime this will raise `NameError` the first time the FSDP branch is executed. Import these symbols at runtime (top-level or inside the FSDP branch), or change the casts to use string annotations (e.g., `typing.cast("FSDPConfig", ...)`) so they don't require runtime names.
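A minimal sketch of the string-annotation variant; the body of `apply_distributed` is reduced to the cast, and the import path under `TYPE_CHECKING` is an assumption:

```python
import typing

if typing.TYPE_CHECKING:
    # resolved only by the type checker, never imported at runtime
    from sarasa.config import FSDPConfig  # assumed import path


def apply_distributed(config: object) -> object:
    # typing.cast with a string annotation is a runtime no-op, so the
    # TYPE_CHECKING-only import can no longer raise NameError here
    config = typing.cast("FSDPConfig", config)
    return config
```

`typing.cast` never inspects its first argument at runtime, so passing the name as a string changes nothing for the type checker while removing the runtime dependency.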
This pull request introduces support for float8 quantization in model training, refactors dtype handling to use a new enum for safer and more extensible configuration, and integrates these changes throughout the codebase. The most important changes are grouped below:
Quantization Support
- Added a new `sarasa/quantize.py` module with a `to_float8` function that converts models for float8 training using `torchao`. This enables efficient float8 quantization and filtering of eligible modules.
- Updated `pyproject.toml` to include `torchao` as a dependency for quantization features.

Dtype Refactoring
- Introduced a `Dtype` enum in `sarasa/config.py` to replace string literals for dtype configuration, improving type safety and extensibility. All relevant configuration fields (`dtype`, `amp_dtype`) now use this enum. [1] [2]
- Updated usages in `sarasa/train.py` and `sarasa/utils.py` to reference the new `Dtype` enum. [1] [2] [3]

Training Integration
- Updated `sarasa/train.py` to convert the model to float8 when `amp_dtype` is set to `Dtype.float8`, and adjusted autocast logic to account for the new dtype enum and float8 handling. [1] [2]