@HosseinKaviani-H (Contributor) commented Jan 22, 2026

Summary

Adds a TrainBatch dataclass that separates model_inputs from loss_inputs, so new training paradigms can be supported without changing the batch type.

Motivation

The current TextTrainBatch has limitations:

  • Hardcoded fields require changes for each new training mode
  • Text-only naming doesn't support multimodal
  • Every new paradigm (DPO, distillation, etc.) needs type updates

Solution

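from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]  # unpacked into the model call as keyword arguments
    loss_inputs: dict[str, Any]   # unpacked into the loss call as keyword arguments
    meta: dict[str, Any] = field(default_factory=dict)

# Usage:
logits = model(**batch.model_inputs)
loss = loss_fn(logits, **batch.loss_inputs)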
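
To illustrate the "no type changes" claim, here is a minimal sketch in plain Python rather than the PR's actual code. The names toy_model, sft_loss, weighted_loss, and step, and the keys "tokens", "targets", and "advantages", are hypothetical stand-ins. It shows two different loss functions consuming the same TrainBatch type, differing only in the keys carried by loss_inputs.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]
    loss_inputs: dict[str, Any]
    meta: dict[str, Any] = field(default_factory=dict)


def toy_model(tokens: list[int]) -> list[float]:
    # Hypothetical stand-in for a real model: one "logit" per token.
    return [float(t) for t in tokens]


def sft_loss(logits: list[float], targets: list[int]) -> float:
    # Mean squared error, standing in for a supervised loss.
    return sum((p - t) ** 2 for p, t in zip(logits, targets)) / len(targets)


def weighted_loss(
    logits: list[float], targets: list[int], advantages: list[float]
) -> float:
    # Same error term, weighted per token; stands in for an RL-style loss
    # that needs an extra input the supervised loss does not.
    terms = [a * (p - t) ** 2 for p, t, a in zip(logits, targets, advantages)]
    return sum(terms) / len(targets)


def step(batch: TrainBatch, loss_fn: Callable[..., float]) -> float:
    # The training step never names individual fields; it only unpacks dicts.
    logits = toy_model(**batch.model_inputs)
    return loss_fn(logits, **batch.loss_inputs)


sft_batch = TrainBatch(
    model_inputs={"tokens": [1, 2, 3]},
    loss_inputs={"targets": [1, 2, 5]},
)
rl_batch = TrainBatch(
    model_inputs={"tokens": [1, 2, 3]},
    loss_inputs={"targets": [0, 2, 3], "advantages": [0.5, 1.0, 2.0]},
    meta={"paradigm": "rl"},
)

sft_value = step(sft_batch, sft_loss)
rl_value = step(rl_batch, weighted_loss)
print(sft_value, rl_value)
```

One trade-off of this design is that a mismatch between the keys in loss_inputs and the loss function's parameters only surfaces at call time, as a TypeError, rather than at type-check time.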

Files Changed

| File | Change |
|---|---|
| src/forge/types.py | Added TrainBatch dataclass |
| src/forge/rl/collate.py | Updated to return list[TrainBatch] with model_inputs/loss_inputs |
| src/forge/actors/trainer/titan.py | Updated train_step() to accept list[TrainBatch] and unpack fields |
| apps/grpo/main.py | Updated to pass batch directly: trainer.train_step.call(batch) |
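
The data flow described by these changes might look roughly like the sketch below. It is not the code in collate.py or titan.py: the collate and train_step signatures, the episode keys ("tokens", "advantage"), and the toy model and loss are all assumptions made for illustration, and plain functions stand in for the actor endpoint that main.py calls.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]
    loss_inputs: dict[str, Any]
    meta: dict[str, Any] = field(default_factory=dict)


def collate(episodes: list[dict[str, Any]], batch_size: int) -> list[TrainBatch]:
    """Group episodes into batches, splitting fields by who consumes them."""
    batches = []
    for start in range(0, len(episodes), batch_size):
        chunk = episodes[start:start + batch_size]
        batches.append(
            TrainBatch(
                model_inputs={"tokens": [ep["tokens"] for ep in chunk]},
                loss_inputs={"advantages": [ep["advantage"] for ep in chunk]},
                meta={"num_episodes": len(chunk)},
            )
        )
    return batches


def train_step(
    batches: list[TrainBatch],
    model: Callable[..., list[float]],
    loss_fn: Callable[..., float],
) -> float:
    """Run forward pass and loss for each batch; return the mean loss."""
    losses = []
    for batch in batches:
        logits = model(**batch.model_inputs)
        losses.append(loss_fn(logits, **batch.loss_inputs))
    return sum(losses) / len(losses)


def toy_model(tokens: list[list[int]]) -> list[float]:
    # Hypothetical model: one score per sequence.
    return [float(sum(seq)) for seq in tokens]


def toy_loss(logits: list[float], advantages: list[float]) -> float:
    # Hypothetical advantage-weighted objective, negated so lower is better.
    return -sum(a * s for a, s in zip(advantages, logits)) / len(logits)


episodes = [
    {"tokens": [1, 2], "advantage": 1.0},
    {"tokens": [3], "advantage": -1.0},
    {"tokens": [4, 5], "advantage": 0.5},
]
batches = collate(episodes, batch_size=2)
mean_loss = train_step(batches, toy_model, toy_loss)
print(len(batches), mean_loss)
```

In this sketch the caller hands the collated list straight to train_step without repacking, which mirrors the "pass batch directly" change listed for main.py.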

Test Plan

  • Core implementation: types.py, collate.py, titan.py, main.py
  • Update test files (tests/sandbox/)
  • Update documentation (docs/)

@meta-cla (bot) added the CLA Signed label on Jan 22, 2026
@felipemello1 (Contributor) left a comment

I don't think this class should be in trainer.py. It should probably be in types.py or something like that. Are you also going to add it to collate and test it in this PR?

@joecummings (Member)

> I don't think this class should be in trainer.py. It should probably be in types.py or something like that. Are you also going to add it to collate and test it in this PR?

Why wouldn't this be in the trainer.py file under api? It defines the training API of which this is part. I would vote to keep it in the trainer API.

@felipemello1 (Contributor)

> Why wouldn't this be in the trainer.py file under api?

This is also used in collate_fn, and I'm not sure whether it will be used in other places. I think we would be exposed to circular dependencies.

e.g. collate imports from train
train imports from X
X imports from collate

Also, that's what other frameworks do, like tinker: https://github.com/thinking-machines-lab/tinker/blob/ad03d44978096b1dcae662e469293e70f509d5a8/src/tinker/types/datum.py#L25
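
The cycle sketched in this comment can be reproduced with three throwaway modules. The example below is only an illustration: every module name (cyc_collate, cyc_train, cyc_x, and the ok_* variants) is hypothetical and prefixed to avoid clashing with real packages, and the modules are written to a temporary directory. It contrasts the cyclic layout with one where the shared class lives in a leaf module that imports nothing.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Cyclic layout: the shared class lives in the trainer module.
CYCLIC = {
    "cyc_collate": """
        from cyc_train import TrainBatch      # collate imports from train
        def collate_fn(): return [TrainBatch()]
    """,
    "cyc_train": """
        from cyc_x import helper              # train imports from X
        class TrainBatch: pass
    """,
    "cyc_x": """
        from cyc_collate import collate_fn    # X imports from collate
        def helper(): return collate_fn()
    """,
}

# Acyclic layout: the shared class lives in a leaf "types" module.
ACYCLIC = {
    "ok_types": """
        class TrainBatch: pass
    """,
    "ok_collate": """
        from ok_types import TrainBatch
        def collate_fn(): return [TrainBatch()]
    """,
    "ok_x": """
        from ok_collate import collate_fn
        def helper(): return collate_fn()
    """,
    "ok_train": """
        from ok_types import TrainBatch
        from ok_x import helper
    """,
}


def write_modules(root: Path, sources: dict[str, str]) -> None:
    # Write each source string to <root>/<name>.py.
    for name, src in sources.items():
        (root / f"{name}.py").write_text(textwrap.dedent(src))


def try_import(name: str) -> str:
    # Return "ok" on success, or the ImportError message on failure.
    try:
        importlib.import_module(name)
        return "ok"
    except ImportError as exc:
        return f"ImportError: {exc}"


with tempfile.TemporaryDirectory() as tmp:
    write_modules(Path(tmp), {**CYCLIC, **ACYCLIC})
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        cyclic_result = try_import("cyc_collate")
        acyclic_result = try_import("ok_train")
    finally:
        sys.path.remove(tmp)

print(cyclic_result)
print(acyclic_result)
```

The first import fails because cyc_x asks for collate_fn while cyc_collate is still only partially initialized. The second succeeds because every import chain ends at ok_types, which has no imports of its own.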

@joecummings (Member)

> e.g. collate imports from train
> train imports from X
> X imports from collate

What would X be here? I will not hold up the PR on this point, but I am curious because I have a hard time imagining what that would be.

@felipemello1 (Contributor) commented Jan 23, 2026

> What would X be here?

I will leave that as an exercise for the reader

JK, I guess it cannot happen if collate is its own file and doesn't really import from anywhere. Putting the class in types.py just makes more sense to me, given the patterns I have seen. But no big deal either way. Worst case, we refactor later.
