Refactor #681
Conversation
- Create TitanTrainer dataclass in src/forge/actors/trainer/titan.py
- TitanTrainer uses composition with ForgeEngine (self.engine)
- Extract common methods: forward_backward, train_step_sft, setup_metric_logger
- Update ForgeSFTRecipe to inherit from TitanTrainer
- Move compile config to top-level in YAML (required by ForgeJobConfig)
- Add documentation for trainer refactoring
- Move compile from training.compile to compile.enable in llama3_8b.yaml
- Move compile from training.compile to compile.enable in qwen3_8b.yaml

This aligns the YAML config structure with ForgeJobConfig's expected schema, where compile is a separate top-level dataclass, not nested under training.
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #681      +/-   ##
==========================================
+ Coverage   83.05%   83.83%   +0.77%
==========================================
  Files          31       31
  Lines        4036     3946      -90
==========================================
- Hits         3352     3308      -44
+ Misses        684      638      -46
```
```python
self.engine.checkpointer.load(step=self.step)
self.engine.optimizers.zero_grad()

async def setup_metric_logger(self):
```
instead of adding this, can we use record_metric like is done in other places in this file?
```python
loss.backward()
return loss

def train_step_sft(self, batch: dict[str, Tensor]) -> None:
```
why do we prefer this to live on TitanTrainer vs SFTTrainer?
```python
profiler: Profiler
device: torch.device
step: int

class ForgeSFTRecipe(TitanTrainer):
```
Can we rename this SFTTrainer?
```diff
 def __repr__(self) -> str:
-    return "Trainer"
+    return "ForgeSFTRecipe"
```
do we need this method?
```python
import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
from forge.actors.trainer import TitanTrainer
```
nit: can we make the bottom of this file emulate grpo/main.py?
grpo/main.py, lines 368 to 374 in b4b0e6a:

```python
if __name__ == "__main__":

    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

    _main()  # @parse grabs the cfg from CLI
```
```python
# Already initialized in __post_init__
pass

def _setup_engine(self):
    """Initialize the ForgeEngine (non-endpoint helper for subclasses)."""
```
why this change? setup is called by ForgeActor here:
torchforge/src/forge/controller/actor.py, line 222 in b4b0e6a:

```python
await actor.setup.call()
```
thanks for the PR! Left some comments :) Also, nit but can we add a more descriptive title?
RFC: TitanTrainer Base Class Extraction and SFT Refactoring
IMPORTANT NOTE: In this PR I have only tried to move the training APIs out of the SFT recipe; this makes it easier to review. In subsequent PRs, further changes and a potential merge between SFT and GRPO can be implemented, as recommended later in this description.
Executive Summary
Refactored the SFT recipe to extract reusable training components into a new `TitanTrainer` base class, establishing unified training infrastructure for both SFT and RL workloads.

Key Changes:
- New `TitanTrainer` base class in `src/forge/actors/trainer/titan.py`
- Updated `ForgeSFTRecipe` to inherit from `TitanTrainer`
- Composition with `ForgeEngine` (`self.engine`)

Impact: -215 lines of duplicated code, clearer architecture, easier extensibility
Motivation
Problems Solved
- `ForgeSFTRecipe` inherited from both `ForgeActor` and `ForgeEngine`

Goals
- Use composition with `ForgeEngine` (`self.engine`) rather than multiple inheritance

Architecture
Key Design Decisions
- SFT uses `engine.loss_fn`, RL uses a custom `self.loss`
- `train_step()` endpoint for RL external control, `train_step_sft()` for the SFT internal loop
- `TitanTrainer` only handles engine config, recipes handle specific fields
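A hypothetical, minimal sketch of the first decision (toy model and losses, a stub engine class; only the method names and the `engine.loss_fn` vs `self.loss` split come from this PR):

```python
# Hypothetical sketch of the loss-source split; only the method names and the
# engine.loss_fn vs self.loss distinction come from the PR, the rest is a toy.
import torch
from torch import Tensor


class EngineStub:
    """Stand-in for ForgeEngine: owns the model and the built-in SFT loss."""

    def __init__(self) -> None:
        self.model = torch.nn.Linear(8, 8)
        self.loss_fn = torch.nn.functional.mse_loss  # engine-provided loss (SFT)


class TitanTrainerSketch:
    def __init__(self) -> None:
        self.engine = EngineStub()               # composition: self.engine
        self.loss = torch.nn.functional.l1_loss  # recipe-defined loss (RL)

    def forward_backward(self, inputs: Tensor, labels: Tensor) -> Tensor:
        # SFT path: the engine supplies the loss function.
        loss = self.engine.loss_fn(self.engine.model(inputs), labels)
        loss.backward()
        return loss

    def forward_backward_rl(self, inputs: Tensor, labels: Tensor) -> Tensor:
        # RL path: same structure, but the recipe supplies self.loss.
        # (This duplication is what the Future Work section proposes to factor out.)
        loss = self.loss(self.engine.model(inputs), labels)
        loss.backward()
        return loss
```

In this picture, the SFT recipe's internal loop would call `forward_backward()` via `train_step_sft()`, while an RL controller would drive single steps through the `train_step()` endpoint.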
Implementation
Components Extracted to TitanTrainer
- `_setup_engine()`
- `rank_should_record_loss`
- `setup_metric_logger()`
- `record_batch_metrics()`
- `forward_backward()` (`engine.loss_fn`)
- `forward_backward_rl()` (`self.loss`)
- `train_step_sft()`
- `train_step()`
- `push_weights()`
- `cleanup()`
Before/After Comparison
Before: `ForgeSFTRecipe` inherited from both `ForgeActor` and `ForgeEngine`.
After: `ForgeSFTRecipe` inherits from `TitanTrainer` and accesses the engine through composition (`self.engine`).
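A hypothetical sketch of this change (stub base classes and made-up fields; only the class names and the shift from inheritance to composition come from the PR):

```python
# Hypothetical before/after sketch; class names come from the PR description,
# everything else (stub bases, fields, bodies) is illustrative.
from dataclasses import dataclass, field
from typing import Any


class ForgeActorStub:
    """Stand-in for forge.controller.ForgeActor."""


class ForgeEngineStub:
    """Stand-in for torchtitan's ForgeEngine."""

    def __init__(self, job_config: dict[str, Any]) -> None:
        self.job_config = job_config


# Before: the recipe inherited from both the actor and the engine.
class ForgeSFTRecipeBefore(ForgeActorStub, ForgeEngineStub):
    def __init__(self, job_config: dict[str, Any]) -> None:
        ForgeEngineStub.__init__(self, job_config)


# After: a TitanTrainer base class owns the engine by composition,
# and the SFT recipe only adds SFT-specific behavior on top.
@dataclass
class TitanTrainerSketch(ForgeActorStub):
    job_config: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.engine = ForgeEngineStub(self.job_config)  # self.engine composition


class ForgeSFTRecipeAfter(TitanTrainerSketch):
    def train_step_sft(self, batch: dict[str, Any]) -> None:
        # The SFT-specific loop body would call shared helpers such as forward_backward().
        pass


recipe = ForgeSFTRecipeAfter({"compile": {"enable": True}})
print(type(recipe.engine).__name__)  # ForgeEngineStub
```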
File Changes
- `src/forge/actors/trainer/titan.py`: new `TitanTrainer` base class
- `apps/sft/main.py`: `ForgeSFTRecipe` now inherits from `TitanTrainer`
- `apps/sft/llama3_8b.yaml`: compile config moved to top level
- `docs/TRAINER_REFACTORING.md`: new documentation for the trainer refactoring

Net: +974 additions, -215 deletions
Breaking Changes
❌ Config Structure: `compile` moved from `training.compile` to top-level `compile.enable`

Migration: Update YAML configs to the new structure (required by the `ForgeJobConfig` schema)
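As a rough illustration of why the key moves, here is a hypothetical schema sketch; it is not the actual `ForgeJobConfig` definition, and every name other than `compile.enable` is made up:

```python
# Hypothetical sketch only: when `compile` is its own top-level dataclass field,
# the config loader expects a top-level `compile:` mapping with an `enable` key,
# rather than a value nested under `training`.
from dataclasses import dataclass, field


@dataclass
class Compile:
    enable: bool = False


@dataclass
class Training:
    steps: int = 1000  # illustrative field only


@dataclass
class JobConfigSketch:
    training: Training = field(default_factory=Training)
    compile: Compile = field(default_factory=Compile)  # maps to top-level `compile.enable`


cfg = JobConfigSketch(compile=Compile(enable=True))
print(cfg.compile.enable)  # True
```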
Future Work
Recommended: Extract Helper Functions (Priority)
Current implementation has code duplication in non-PP forward/backward paths. Recommended refactor:
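A hypothetical sketch of that refactor (the helper name `_forward_backward_impl`, the toy model, and all bodies are made up; only the two public method names come from this PR): keep the separate public methods but share the non-PP implementation behind one helper parameterized by the loss function.

```python
# Hypothetical sketch of the proposed helper extraction: one shared non-PP
# forward/backward path, parameterized by the loss function, with thin SFT/RL wrappers.
from typing import Callable

import torch
from torch import Tensor

LossFn = Callable[[Tensor, Tensor], Tensor]


class TrainerWithSharedHelper:
    def __init__(self) -> None:
        self.model = torch.nn.Linear(8, 8)
        self.engine_loss_fn: LossFn = torch.nn.functional.mse_loss  # engine loss (SFT)
        self.loss: LossFn = torch.nn.functional.l1_loss             # recipe loss (RL)

    def _forward_backward_impl(self, inputs: Tensor, labels: Tensor, loss_fn: LossFn) -> Tensor:
        """Single non-PP forward/backward path, shared by both public methods."""
        loss = loss_fn(self.model(inputs), labels)
        loss.backward()
        return loss

    def forward_backward(self, inputs: Tensor, labels: Tensor) -> Tensor:
        # SFT keeps its own public method but delegates to the shared helper.
        return self._forward_backward_impl(inputs, labels, self.engine_loss_fn)

    def forward_backward_rl(self, inputs: Tensor, labels: Tensor) -> Tensor:
        # RL does the same, just with the recipe-defined loss.
        return self._forward_backward_impl(inputs, labels, self.loss)
```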
Benefits:
- Less duplication between the two forward/backward paths (`forward_backward` vs `forward_backward_rl`)

Why not merge completely? Different signatures and use cases make separate methods clearer than one method with boolean flags.
Other Future PRs
PR 2: Enhanced Features
Success Metrics
Code Quality:
Functionality: