Skip to content

Conversation

@HosseinKaviani-H
Copy link
Contributor

RFC: TitanTrainer Base Class Extraction and SFT Refactoring

IMPORTANT NOTE: In this PR, I have tried to move training APIs from the SFT recipe. It makes it easier to review this way. In the subsequent PRs, more changes and potential merging between SFT and GRPO can be implemented as I recommended in this review.

Executive Summary

Refactored SFT recipe to extract reusable training components into a new TitanTrainer base class, establishing unified training infrastructure for both SFT and RL workloads.

Key Changes:

  • Created TitanTrainer base class in src/forge/actors/trainer/titan.py
  • Refactored ForgeSFTRecipe to inherit from TitanTrainer
  • Adopted composition pattern for ForgeEngine (self.engine)
  • Extracted common training utilities (forward/backward, metrics, checkpointing)
  • Fixed config handling to separate SFT-specific and engine-level configs

Impact: -215 lines of duplicated code, clearer architecture, easier extensibility


Motivation

Problems Solved

  1. Fragile Multiple Inheritance: ForgeSFTRecipe inherited from both ForgeActor and ForgeEngine
  2. Code Duplication: Training logic duplicated across SFT and RL recipes
  3. Poor Separation: SFT-specific logic mixed with generic training logic
  4. Limited Reusability: No clear path to share components between workloads

Goals

  • Create reusable foundation for SFT and RL
  • Replace inheritance with composition for ForgeEngine
  • Enable code sharing across recipes
  • Maintain all existing functionality

Architecture

ForgeActor (base)
    │
    └── TitanTrainer (generic trainer, uses self.engine composition)
            │
            └── ForgeSFTRecipe (SFT-specific: data, eval, training loop)

ForgeEngine (standalone, accessed via self.engine)

Key Design Decisions

Decision Rationale
Composition for ForgeEngine Avoids diamond inheritance, explicit dependencies, easier testing
Inheritance for Recipes Recipes "are-a" trainer, automatic access to base utilities
Two forward_backward variants SFT uses engine.loss_fn, RL uses custom self.loss
Endpoint vs non-endpoint train_step() endpoint for RL external control, train_step_sft() for SFT internal loop
Config filtering TitanTrainer only handles engine config, recipes handle specific fields

Implementation

Components Extracted to TitanTrainer

Component Purpose Used By
_setup_engine() Initialize ForgeEngine SFT, RL
rank_should_record_loss PP stage-aware loss recording SFT, RL
setup_metric_logger() Metric logging setup SFT, RL
record_batch_metrics() Batch metric recording SFT
forward_backward() Forward/backward with PP+CP (uses engine.loss_fn) SFT
forward_backward_rl() Forward/backward with custom loss (uses self.loss) RL
train_step_sft() SFT training step SFT
train_step() RL training step (endpoint) RL
push_weights() Push to torchstore RL
cleanup() Resource cleanup SFT, RL

Before/After Comparison

Before:

class ForgeSFTRecipe(ForgeActor, ForgeEngine):  # Multiple inheritance
    def forward_backward(...):
        # 60+ lines of PP/CP handling
        loss = self.loss_fn(pred, labels)

    def train_step(self, batch):
        self.optimizers.step()  # Direct access

After:

class ForgeSFTRecipe(TitanTrainer):  # Single inheritance
    # forward_backward() inherited
    # train_step_sft() inherited

    def setup_data(...):
        dp_mesh = self.engine.parallel_dims.world_mesh.get_group("dp")  # Composition

File Changes

File Changes Description
src/forge/actors/trainer/titan.py +149 New TitanTrainer base class
apps/sft/main.py +81/-215 Refactored to use TitanTrainer
apps/sft/llama3_8b.yaml +2/-6 Config structure (compile → top-level)
docs/TRAINER_REFACTORING.md +322 Design docs

Net: +974 additions, -215 deletions


Breaking Changes

Config Structure: compile moved from training.compile to top-level compile.enable

# OLD
training:
  compile: false

# NEW
compile:
  enable: false

Migration: Update YAML configs to new structure (required by ForgeJobConfig schema)


Future Work

Recommended: Extract Helper Functions (Priority)

Current implementation has code duplication in non-PP forward/backward paths. Recommended refactor:

# Extract shared logic
def _compute_loss_non_pp(self, model, inputs, targets, use_custom_loss, cp_ctx):
    """Shared non-PP forward/backward logic."""
    with self.engine.train_context(cp_ctx):
        with self.engine.maybe_enable_amp:
            if use_custom_loss:
                logits = model(**inputs)
                loss = self.loss(logits, **targets)
            else:
                pred = model(inputs)
                loss = self.engine.loss_fn(pred, targets)
    return loss

def forward_backward(self, input_dict, labels, skip_backward=False):
    """SFT version - clear API."""
    inputs = input_dict["tokens"]
    cp_ctx = self._setup_cp(inputs, labels) if self.engine.parallel_dims.cp_enabled else None

    if self.engine.parallel_dims.pp_enabled:
        loss = self._forward_backward_pp(inputs, labels, cp_ctx)
    else:
        loss = self._compute_loss_non_pp(model, inputs, labels, use_custom_loss=False, cp_ctx)
        if not skip_backward:
            loss.backward()
    return loss

def forward_backward_rl(self, inputs, targets):
    """RL version - clear API."""
    if self.engine.parallel_dims.pp_enabled:
        raise NotImplementedError("PP not yet supported for RL")
    loss = self._compute_loss_non_pp(model, inputs, targets, use_custom_loss=True, cp_ctx=None)
    loss.backward()
    return loss

Benefits:

  • Reduces non-PP path duplication (~12 lines)
  • Keeps clear public API (forward_backward vs forward_backward_rl)
  • Easier to add RL PP support later
  • Better testability

Why not merge completely? Different signatures and use cases make separate methods clearer than one method with boolean flags.

Other Future PRs

PR 2: Enhanced Features

  • Gradient clipping utilities
  • Bringing Titan APIs as they are for less duplicated code

Success Metrics

Code Quality:

  • ✅ -215 lines of duplicated code
  • ✅ Composition pattern: 100% for engine access
  • ✅ Clearer inheritance hierarchy

Functionality:

  • ✅ Full feature parity with pre-refactoring
  • ✅ All distributed training modes working (PP, CP, DP, FSDP)

Hossein Kavianihamedani and others added 5 commits December 18, 2025 16:50
- Create TitanTrainer dataclass in src/forge/actors/trainer/titan.py
- TitanTrainer uses composition with ForgeEngine (self.engine)
- Extract common methods: forward_backward, train_step_sft, setup_metric_logger
- Update ForgeSFTRecipe to inherit from TitanTrainer
- Move compile config to top-level in YAML (required by ForgeJobConfig)
- Add documentation for trainer refactoring
- Move compile from training.compile to compile.enable in llama3_8b.yaml
- Move compile from training.compile to compile.enable in qwen3_8b.yaml

This aligns the YAML config structure with ForgeJobConfig's expected schema,
where compile is a separate top-level dataclass, not nested under training.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 26, 2025
@codecov-commenter
Copy link

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 83.83%. Comparing base (7b8580a) to head (0d52553).
⚠️ Report is 12 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #681      +/-   ##
==========================================
+ Coverage   83.05%   83.83%   +0.77%     
==========================================
  Files          31       31              
  Lines        4036     3946      -90     
==========================================
- Hits         3352     3308      -44     
+ Misses        684      638      -46     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

self.engine.checkpointer.load(step=self.step)
self.engine.optimizers.zero_grad()

async def setup_metric_logger(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of adding this, can we use record_metric like is done in other places in this file?

loss.backward()
return loss

def train_step_sft(self, batch: dict[str, Tensor]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we prefer this to live on TitanTrainer vs SFTTrainer?

profiler: Profiler
device: torch.device
step: int
class ForgeSFTRecipe(TitanTrainer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this SFTTrainer?


def __repr__(self) -> str:
return "Trainer"
return "ForgeSFTRecipe"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this method?


import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
from forge.actors.trainer import TitanTrainer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we make the bottom of this file emulate grpo/main.py?

if __name__ == "__main__":
@parse
def _main(cfg):
asyncio.run(main(cfg))
_main() # @parse grabs the cfg from CLI

Comment on lines +131 to +135
# Already initialized in post_init__
pass

def _setup_engine(self):
"""Initialize the ForgeEngine (non-endpoint helper for subclasses)."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change? setup is called by ForgeActor here:

await actor.setup.call()

@daniellepintz
Copy link
Contributor

thanks for the PR! Left some comments :) Also, nit but can we add a more descriptive title?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants