Conversation
Pull request overview
This PR refactors the distributed training configuration system to use a unified Distributed class instead of separate DDP and FSDP classes. The changes also fix an evaluation loss calculation bug where total_tokens was accumulated before the all_reduce operation in distributed settings.
Changes:
- Unified distributed configuration using `dp_replicate_degree` and `dp_shard_degree` parameters instead of separate DDP/FSDP classes
- Fixed evaluation loss calculation to correctly accumulate `total_tokens` after the `all_reduce` operation
- Updated imports, configuration files, and tests to use the new `Distributed` class
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| sarasa/config.py | Replaced DDP/FSDP classes with unified Distributed class, added dp_replicate_degree and dp_shard_degree parameters, moved FSDP-specific options into Distributed class |
| sarasa/__init__.py | Updated exports to use Distributed instead of DDP/FSDP |
| sarasa/utils.py | Updated apply_distributed to check dp_replicate_degree/dp_shard_degree and accept explicit dtype parameters |
| sarasa/train.py | Pass dtype parameters explicitly to apply_distributed, updated AMP context condition |
| sarasa/evaluate.py | Fixed evaluation loss calculation to correctly handle distributed token counting |
| configs/llama3-1b.py | Removed explicit FSDP configuration, now uses default Distributed |
| tests/test_config.py | Removed test for old FSDP dtype inheritance, minor formatting change |
```python
if not (self.dp_replicate_degree == 1 and self.dp_shard_degree == -1):
    raise NotImplementedError()
```
The `__post_init__` validation logic is incorrect. It only allows the default configuration where `dp_replicate_degree=1` and `dp_shard_degree=-1` (FSDP mode), and rejects any other configuration, including DDP mode. This means users cannot actually use DDP with the new unified class. The validation should allow either DDP (`dp_replicate_degree != 1, dp_shard_degree == 1`) or FSDP (`dp_replicate_degree == 1, dp_shard_degree == -1`), not only FSDP.
```suggestion
# Allow either:
# - FSDP mode: no replication, full sharding
# - DDP mode: replication, no sharding
is_fsdp = self.dp_replicate_degree == 1 and self.dp_shard_degree == -1
is_ddp = self.dp_replicate_degree != 1 and self.dp_shard_degree == 1
if not (is_fsdp or is_ddp):
    raise NotImplementedError(
        f"Unsupported distributed configuration: "
        f"dp_replicate_degree={self.dp_replicate_degree}, "
        f"dp_shard_degree={self.dp_shard_degree}. "
        "Supported modes are: "
        "FSDP (dp_replicate_degree=1, dp_shard_degree=-1) or "
        "DDP (dp_replicate_degree!=1, dp_shard_degree=1)."
    )
```
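To illustrate, here is a minimal standalone sketch of a `Distributed` dataclass with the corrected two-mode validation. The field names and defaults are taken from the diff, but everything else (no other fields, message text) is assumed for illustration, not the actual `sarasa/config.py` implementation:

```python
import dataclasses


@dataclasses.dataclass
class Distributed:
    dp_replicate_degree: int = 1  # DDP replication factor; 1 disables replication
    dp_shard_degree: int = -1     # FSDP sharding degree; -1 means full sharding

    def __post_init__(self) -> None:
        # FSDP mode: no replication, full sharding
        is_fsdp = self.dp_replicate_degree == 1 and self.dp_shard_degree == -1
        # DDP mode: replication, no sharding
        is_ddp = self.dp_replicate_degree != 1 and self.dp_shard_degree == 1
        if not (is_fsdp or is_ddp):
            raise NotImplementedError(
                f"Unsupported distributed configuration: "
                f"dp_replicate_degree={self.dp_replicate_degree}, "
                f"dp_shard_degree={self.dp_shard_degree}."
            )
```

With this version, `Distributed()` (default FSDP) and `Distributed(dp_replicate_degree=8, dp_shard_degree=1)` (DDP) both construct successfully, while mixed configurations such as `Distributed(dp_replicate_degree=2, dp_shard_degree=-1)` still raise.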
```python
@dataclasses.dataclass
class Distributed:
    dp_replicate_degree: int = 1
    """Degree of DDP. 1 means no replication"""
```
The docstring is misleading. "Degree of DDP" suggests this controls the degree of data parallelism when using DDP, but a value of 1 means "no replication" which implies no DDP at all. Consider clarifying this to something like "Data parallelism replication factor. Set to 1 to disable DDP replication (use FSDP or single-device training), or > 1 to enable DDP with that replication factor."
| """Degree of DDP. 1 means no replication""" | |
| """Data parallelism replication factor. Set to 1 to disable DDP replication | |
| (use FSDP or single-device training), or > 1 to enable DDP with that | |
| replication factor.""" |
| """Degree of DDP. 1 means no replication""" | ||
|
|
||
| dp_shard_degree: int = -1 | ||
| """Degree of FSDP. -1 means full sharding""" |
The docstring is misleading. "Degree of FSDP" suggests this controls the degree of sharding when using FSDP, but the value -1 has a special meaning for "full sharding" which is not a degree but a mode indicator. Consider clarifying this to something like "Sharding degree for FSDP. Set to -1 for full sharding, 1 for no sharding (use DDP or single-device training), or > 1 for partial sharding with that degree."
| """Degree of FSDP. -1 means full sharding""" | |
| """Sharding degree for FSDP. Set to -1 for full sharding (the only supported mode).""" |
This pull request refactors the distributed training configuration system to simplify and unify how distributed strategies (DDP and FSDP) are handled. Instead of separate `DDP` and `FSDP` classes, a single `Distributed` class now supports both strategies through configuration parameters. The changes also update related logic throughout the codebase to use the new unified approach and improve the evaluation loss calculation.

Distributed training configuration refactor:
- Removed the `DDP` and `FSDP` classes and replaced them with a unified `Distributed` class that uses `dp_replicate_degree` and `dp_shard_degree` to select between DDP and FSDP. Additional FSDP-specific options are now part of the `Distributed` class, and a `__post_init__` check ensures only supported configurations are used. (`sarasa/config.py` [1] [2] [3] [4]; `sarasa/__init__.py` [5]; `configs/llama3-1b.py` [6] [7])
- Updated all code that previously referenced `DDP` or `FSDP` to use the new `Distributed` class, including configuration creation and CLI loading. (`sarasa/config.py` [1] [2]; `configs/llama3-1b.py` [3] [4])
- Refactored distributed application logic: selection between DDP and FSDP is now based on the values of `dp_replicate_degree` and `dp_shard_degree`, and dtype handling is passed explicitly instead of being set on the config. (`sarasa/utils.py` [1] [2]; `sarasa/train.py` [3] [4])
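The strategy selection described above can be sketched as a small standalone function. This is a hypothetical illustration of the dispatch rule, not the actual `apply_distributed` from `sarasa/utils.py`; the function name `select_strategy` and the string return values are assumptions:

```python
def select_strategy(dp_replicate_degree: int, dp_shard_degree: int) -> str:
    """Pick the wrapping strategy from the two degree parameters (hypothetical sketch)."""
    if dp_replicate_degree == 1 and dp_shard_degree == -1:
        # No replication, full sharding -> FSDP
        return "fsdp"
    if dp_replicate_degree != 1 and dp_shard_degree == 1:
        # Replication, no sharding -> DDP
        return "ddp"
    raise NotImplementedError(
        f"Unsupported combination: dp_replicate_degree={dp_replicate_degree}, "
        f"dp_shard_degree={dp_shard_degree}"
    )
```

The point of the design is that one pair of integers encodes the mode, so the rest of the codebase never has to branch on a class type.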
Evaluation and testing improvements:

- Improved loss calculation in evaluation: it now accumulates per-batch losses and divides by the total number of valid tokens for better accuracy. (`sarasa/evaluate.py` [1] [2])
- Updated or removed tests to reflect the new configuration approach, and removed tests that depended on the old `FSDP` class. (`tests/test_config.py` [1] [2])
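The evaluation fix is about ordering: the reduced (global) per-batch values must be accumulated, rather than accumulating local values before the reduction. A pure-Python simulation of that corrected ordering is sketched below; the `all_reduce_sum` helper stands in for `torch.distributed.all_reduce` with a SUM op, and the data layout and function names are assumptions for illustration, not the actual `sarasa/evaluate.py` code:

```python
def all_reduce_sum(values):
    # Stand-in for torch.distributed.all_reduce(SUM): every rank sees the global sum.
    return sum(values)

def eval_loss(per_rank_batches):
    """per_rank_batches: list over ranks; each rank holds a list of
    (loss_sum, n_valid_tokens) pairs, one per evaluation batch."""
    total_loss = 0.0
    total_tokens = 0
    n_batches = len(per_rank_batches[0])
    for b in range(n_batches):
        batch = [rank[b] for rank in per_rank_batches]
        # The fix: all_reduce first, then accumulate the reduced values,
        # so total_tokens counts every rank's valid tokens exactly once.
        total_loss += all_reduce_sum(loss for loss, _ in batch)
        total_tokens += all_reduce_sum(tokens for _, tokens in batch)
    return total_loss / total_tokens
```

Dividing the globally summed loss by the globally summed token count weights each valid token equally, which is why per-batch accumulation after the reduction gives the more accurate figure.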