Move multi-step training into TrainingConfig with per-step IS correction#39
Conversation

Important: Review skipped. Auto incremental reviews are disabled on this repository.
📝 Walkthrough

This PR implements multi-step gradient updates within single batches and feedback repetition control. Configuration is restructured to nest training parameters under a `training` key.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as DistillationTrainer
    participant PreparedSamples as Prepared Samples
    participant Step as Step Loop
    participant StudentModel as Student Model
    participant Optimizer as Optimizer

    Trainer->>PreparedSamples: Validate & accumulate samples<br/>(full_ids, response_ids, logprobs)
    PreparedSamples-->>Trainer: PreparedSample list
    Trainer->>Step: For each step in steps_per_batch
    Step->>StudentModel: Compute student response logprobs<br/>(current adapter state)
    StudentModel-->>Step: per_step_logprobs
    Step->>Step: Build SDPOLossInput from<br/>prepared samples + new logprobs
    Step->>Step: Compute per-step loss<br/>(distill_loss, kl_reg, clip)
    Step->>Optimizer: Backward & gradient update<br/>with clipping
    Optimizer-->>Step: updated model state
    Step->>StudentModel: Recompute behavior_logprobs<br/>for next step
    StudentModel-->>Step: updated logprobs
    Step-->>Trainer: step metrics & updated state
    Trainer->>Trainer: Aggregate per-step metrics<br/>steps_per_batch_applied
    Trainer-->>Trainer: Return per-step results<br/>& tokens processed
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (5)
claas/training/engine/tinker/engine.py (2)
239-241: Lambda used for averaging — minor style nit.

The `avg` lambda is re-created each loop iteration. Consider extracting it before the loop or using a local function.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@claas/training/engine/tinker/engine.py` around lines 239-241: the averaging lambda `avg` is recreated each loop iteration; extract it as a local helper function or define it once before the loop. Replace the inline assignment `avg = lambda key: ...` with a named function (e.g., `def avg(key): return sum(m[key] for m in sample_metrics) / n`) or move the definition above the loop where `sample_metrics` and `n` are available, and update all uses of `avg` accordingly.
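A minimal sketch of the suggested refactor (the surrounding metric names and loop structure here are assumptions for illustration, not the engine's actual code):

```python
def summarize(sample_metrics: list[dict[str, float]]) -> dict[str, float]:
    """Average each metric across samples, using a named helper
    defined once instead of a lambda re-created per iteration."""
    n = len(sample_metrics)

    def avg(key: str) -> float:
        return sum(m[key] for m in sample_metrics) / n

    # One averaged value per metric key, keyed off the first sample.
    return {key: avg(key) for key in sample_metrics[0]}


metrics = summarize([{"distill_loss": 1.0}, {"distill_loss": 3.0}])
print(metrics)  # {'distill_loss': 2.0}
```

A named `def` also shows up with a useful name in profiles and tracebacks, which a bare lambda does not.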
218-261: Multi-step loop with Tinker SDK: correct, but note the cost of intermediate weight saves.

The flow is sound: build datums → forward/backward → optimizer step → recompute logprobs. The `save_weights_and_get_sampling_client_async` call at line 257 is required by Tinker's architecture to get a sampling client with updated weights, but it means each intermediate step (all except the last) triggers a full weight save. For `steps_per_batch > 2`, this could be a latency concern. Worth documenting this tradeoff or considering whether Tinker offers a lighter-weight way to get an updated sampling client without a full checkpoint save.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@claas/training/engine/tinker/engine.py` around lines 218-261: the loop calls `training_client.save_weights_and_get_sampling_client_async` inside the step loop, which triggers a full weight save on every intermediate step and can cause latency when `steps_per_batch > 2`. Either (a) document this tradeoff just above the loop and in the function docstring, or (b) add a configurable behavior (e.g., a flag like `save_intermediate_weights`) so you only call it for steps where it's necessary (or avoid it until the final step), and, if the Tinker SDK offers a lighter alternative to get an updated sampling client, switch to that API instead.

tests/test_tinker_engine.py (1)
91-122: Consider parameterizing mock save paths for multi-step scenarios.

The `mock_training_client` fixture returns a fixed `save_result.path = "tinker://checkpoints/step-1"` regardless of the checkpoint name passed to `save_state_async`. This works for current tests, but if future tests need to assert that the saved path reflects the actual step, the fixture would need to be updated. Not a blocker.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/test_tinker_engine.py` around lines 91-122: the fixture `mock_training_client` currently returns a fixed `save_result.path` ("tinker://checkpoints/step-1") for `save_state_async`; change it so `save_state_async` uses an AsyncMock side_effect that builds and returns a MagicMock whose `.path` is derived from the checkpoint name/step passed in (e.g., include the step id or checkpoint name from the method args), and do the same for `save_weights_for_sampler_async`/`sampler_save.path` if needed; create `save_result` and `sampler_save` inside the side_effects so tests that call `mock_training_client.save_state_async(...)` receive a result object whose path reflects the input.

claas/training/distillation.py (1)
38-47: PreparedSample name collision with tinker engine.

Both `claas/training/distillation.py` and `claas/training/engine/tinker/engine.py` define a `PreparedSample` TypedDict with different fields (torch.Tensor-based vs. list-based). This works fine since they're module-private, but could cause confusion when navigating the codebase or in IDE symbol search. Consider naming one of them more specifically (e.g., `LocalPreparedSample` or `TinkerPreparedSample`) to disambiguate.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 38-47: the TypedDict `PreparedSample` collides by name with another `PreparedSample` in `claas/training/engine/tinker/engine.py`; rename this TypedDict to a more specific name (e.g., `DistillationPreparedSample` or `LocalPreparedSample`) and update all local type annotations and imports in `claas/training/distillation.py` that reference it (functions, return types, variables) so the module remains unambiguous while preserving the same fields and behavior.

claas/eval/types.py (1)
80-91: Risk of default drift between `EvalTrainingConfig` and `TrainingConfig`.

`EvalTrainingConfig` manually duplicates field names and defaults from the Pydantic `TrainingConfig` (in `claas/core/types.py`). If a default changes in one but not the other, eval runs will silently use stale values. Consider adding a test or a factory that asserts parity.

💡 Example: add a parity test

```python
# tests/test_eval_config.py (or similar)
import dataclasses

from claas.core.types import TrainingConfig
from claas.eval.types import EvalTrainingConfig


def test_eval_training_config_defaults_match():
    """Ensure EvalTrainingConfig defaults stay in sync with TrainingConfig."""
    runtime = TrainingConfig()
    hydra = EvalTrainingConfig()
    for f in dataclasses.fields(hydra):
        assert getattr(hydra, f.name) == getattr(runtime, f.name), (
            f"Default mismatch on '{f.name}': "
            f"EvalTrainingConfig={getattr(hydra, f.name)} vs "
            f"TrainingConfig={getattr(runtime, f.name)}"
        )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@claas/eval/types.py` around lines 80-91: `EvalTrainingConfig` duplicates defaults from the Pydantic `TrainingConfig`, which can drift; add a parity test that instantiates both and asserts all field defaults match (use `dataclasses.fields` on `EvalTrainingConfig` and compare `getattr(hydra, name) == getattr(runtime, name)`), e.g., add `tests/test_eval_config.py` to fail CI if any default on `EvalTrainingConfig` diverges from `TrainingConfig`; alternatively implement a factory that constructs `EvalTrainingConfig` from `TrainingConfig` to guarantee parity and update usages to call that factory instead of hardcoding defaults.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@claas/eval/runner.py`:
- Around line 84-93: The code currently indexes
metadata["steps_per_batch_applied"] in the LocalDistillMetrics return path which
can raise KeyError and drop the entire metrics object; change that access to use
metadata.get("steps_per_batch_applied", 1) so LocalDistillMetrics is constructed
with a default of 1 when the field is absent (keep identical pattern used for
other fields like distill_loss, kl_reg, mean_is_ratio, clip_fraction) to make
the metrics construction resilient.
---
Nitpick comments:
In `@claas/eval/types.py`:
- Around line 80-91: EvalTrainingConfig duplicates defaults from the Pydantic
TrainingConfig which can drift; add a parity test that instantiates
TrainingConfig and EvalTrainingConfig and asserts all field defaults match (use
dataclasses.fields on EvalTrainingConfig and compare getattr(hydra, name) ==
getattr(runtime, name)), e.g. add tests/test_eval_config.py to fail CI if any
default on EvalTrainingConfig diverges from TrainingConfig; alternatively
implement a factory that constructs EvalTrainingConfig from TrainingConfig to
guarantee parity and update usages to call that factory instead of hardcoding
defaults.
In `@claas/training/distillation.py`:
- Around line 38-47: The TypedDict PreparedSample in
claas/training/distillation.py collides by name with another PreparedSample in
claas/training/engine/tinker/engine.py; rename this TypedDict to a more specific
name (e.g., DistillationPreparedSample or LocalPreparedSample) and update all
local type annotations and imports in claas/training/distillation.py that
reference PreparedSample (functions, return types, variables) to use the new
name so the module remains unambiguous while preserving the same fields and
behavior.
In `@claas/training/engine/tinker/engine.py`:
- Around line 239-241: The averaging lambda avg is being recreated each loop
iteration; extract it as a local helper function or define it once before the
loop to avoid recreating the closure repeatedly. Replace the inline lambda
assignment avg = lambda key: ... with a named function (e.g., def avg(key):
return sum(m[key] for m in sample_metrics) / n) or move the lambda definition
above the loop where sample_metrics and n are available, and update all uses of
avg (referenced as avg and sample_metrics in this block) accordingly.
- Around line 218-261: The loop calls
training_client.save_weights_and_get_sampling_client_async inside the step loop
(see save_weights_and_get_sampling_client_async, steps_per_batch and
training_client) which triggers a full weight save on every intermediate step
and can cause latency when steps_per_batch > 2; update the code to either (a)
document this tradeoff just above the loop and in the function docstring, or (b)
add a configurable behavior (e.g., a flag like save_intermediate_weights) so you
only call save_weights_and_get_sampling_client_async for steps where it’s
necessary (or avoid it until the final step), and, if the Tinker SDK offers a
lighter alternative to get an updated sampling client, switch to that API
instead.
In `@tests/test_tinker_engine.py`:
- Around line 91-122: The fixture mock_training_client currently returns a fixed
save_result.path ("tinker://checkpoints/step-1") for save_state_async; change it
so save_state_async uses an AsyncMock side_effect that builds and returns a
MagicMock whose .path is derived from the checkpoint name/step passed into
save_state_async (e.g., include the step id or checkpoint name from the method
args), and do the same for save_weights_for_sampler_async/sampler_save.path if
needed; update references to save_result and sampler_save in the fixture to be
created inside the side_effects so tests that call
mock_training_client.save_state_async(...) will receive a result object with a
path that reflects the input.
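The side_effect pattern described above can be sketched as follows (the mocked method name mirrors the test file; the real SDK method's exact signature is an assumption):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


def make_mock_training_client() -> MagicMock:
    """Build a mock whose saved-checkpoint path reflects the name
    passed in, instead of a single hard-coded path."""
    client = MagicMock()

    def _save_state(name):
        result = MagicMock()
        result.path = f"tinker://checkpoints/{name}"
        return result

    # AsyncMock awaits the call and returns the side_effect's result.
    client.save_state_async = AsyncMock(side_effect=_save_state)
    return client


client = make_mock_training_client()
result = asyncio.run(client.save_state_async("step-3"))
print(result.path)  # tinker://checkpoints/step-3
```

The same side_effect approach applies to `save_weights_for_sampler_async` if its path also needs to track the input.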
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (11)
- claas/core/types.py
- claas/eval/README.md
- claas/eval/config.py
- claas/eval/configs/base.yaml
- claas/eval/runner.py
- claas/eval/types.py
- claas/training/distillation.py
- claas/training/engine/tinker/engine.py
- tests/test_eval_config.py
- tests/test_eval_runner.py
- tests/test_tinker_engine.py
```python
        steps_per_batch_applied=metadata["steps_per_batch_applied"],
    )

return LocalDistillMetrics(
    distill_loss=metadata.get("distill_loss"),
    kl_reg=metadata.get("kl_reg"),
    mean_is_ratio=metadata.get("mean_is_ratio"),
    clip_fraction=metadata.get("clip_fraction"),
    steps_per_batch_applied=metadata["steps_per_batch_applied"],
)
```
Use .get() with a default for steps_per_batch_applied to match resilience of other fields.
Lines 84 and 92 use hard metadata["steps_per_batch_applied"] access. If an older or third-party engine omits this new field, both branches raise KeyError. While the caller catches KeyError (line 409), that discards the entire metrics object — losing distill_loss, kl_mean, etc. for the step.
This is especially inconsistent in the local branch (lines 88–92), where every other field uses .get().
🛡️ Proposed fix — use `.get()` with default 1
```diff
     batch_size=metadata["batch_size"],
-    steps_per_batch_applied=metadata["steps_per_batch_applied"],
+    steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
 )

 return LocalDistillMetrics(
     distill_loss=metadata.get("distill_loss"),
     kl_reg=metadata.get("kl_reg"),
     mean_is_ratio=metadata.get("mean_is_ratio"),
     clip_fraction=metadata.get("clip_fraction"),
-    steps_per_batch_applied=metadata["steps_per_batch_applied"],
+    steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@claas/eval/runner.py` around lines 84 - 93, The code currently indexes
metadata["steps_per_batch_applied"] in the LocalDistillMetrics return path which
can raise KeyError and drop the entire metrics object; change that access to use
metadata.get("steps_per_batch_applied", 1) so LocalDistillMetrics is constructed
with a default of 1 when the field is absent (keep identical pattern used for
other fields like distill_loss, kl_reg, mean_is_ratio, clip_fraction) to make
the metrics construction resilient.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 762072da79
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
…aining

# Conflicts:
#	claas/core/types.py
#	claas/eval/README.md
#	claas/eval/config.py
#	claas/eval/configs/base.yaml
#	claas/eval/runner.py
#	claas/eval/types.py
#	tests/test_eval_config.py
#	tests/test_eval_runner.py
@codex review
♻️ Duplicate comments (1)
claas/eval/runner.py (1)
84-84: ⚠️ Potential issue | 🟠 Major

Use safe default access for `steps_per_batch_applied` to avoid dropping metrics.

Line 84 and Line 92 still use `metadata["steps_per_batch_applied"]`. If omitted by an engine, this raises `KeyError`, and the catch path discards the entire SDPO metrics object for that step.

Suggested fix

```diff
 if config.mode == "tinker" and "adv_mean" in metadata:
     return TinkerDistillMetrics(
@@
-        steps_per_batch_applied=metadata["steps_per_batch_applied"],
+        steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
     )
@@
 return LocalDistillMetrics(
     distill_loss=metadata.get("distill_loss"),
     kl_reg=metadata.get("kl_reg"),
     mean_is_ratio=metadata.get("mean_is_ratio"),
     clip_fraction=metadata.get("clip_fraction"),
-    steps_per_batch_applied=metadata["steps_per_batch_applied"],
+    steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
 )
```

Also applies to: 92-92
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/eval/runner.py` at line 84, replace direct indexing of `metadata["steps_per_batch_applied"]` with a safe lookup that supplies a sensible default (e.g., `metadata.get("steps_per_batch_applied", 1)`) to avoid raising KeyError and dropping the SDPO metrics object; update both occurrences that reference `steps_per_batch_applied` in `claas.eval.runner` (the two places around the current lines using `metadata["steps_per_batch_applied"]`) so downstream logic receives the fallback value when the engine omits the key.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@claas/eval/runner.py`:
- Line 84: Replace direct indexing of metadata["steps_per_batch_applied"] with a
safe lookup that supplies a sensible default (e.g.,
metadata.get("steps_per_batch_applied", 1)) to avoid raising KeyError and
dropping the SDPO metrics object; update both occurrences that reference
steps_per_batch_applied in claas.eval.runner (the two places around the current
lines using metadata["steps_per_batch_applied"]) so downstream logic receives
the fallback value when the engine omits the key.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (7)
- claas/core/types.py
- claas/eval/README.md
- claas/eval/configs/base.yaml
- claas/eval/runner.py
- claas/eval/types.py
- tests/test_eval_config.py
- tests/test_eval_runner.py
🚧 Files skipped from review as they are similar to previous changes (3)
- claas/core/types.py
- claas/eval/types.py
- tests/test_eval_config.py
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e2331f72b5
```python
    max_grad_norm: float = 1.0
    kl_reg_weight: float = 0.0
    teacher_top_k: int = 100
    steps_per_batch: int = 4
```
Enforce positive `steps_per_batch` in TrainingConfig

The newly added `steps_per_batch` field has no lower-bound validation, but both multi-step trainers now assume at least one iteration and unconditionally read `step_metrics[-1]` (claas/training/distillation.py and claas/training/engine/tinker/engine.py). As a result, `training.steps_per_batch=0` is currently accepted and then crashes `/v1/feedback` with a server error instead of failing with a clean 4xx validation error; this can break eval runs by turning every feedback update into a failed request.
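A minimal stdlib sketch of the missing guard (the real `TrainingConfig` is a Pydantic model, where `Field(default=4, ge=1)` would express the same bound; field names follow the snippet above):

```python
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Sketch: reject non-positive steps_per_batch at construction so the
    bad value fails fast instead of crashing /v1/feedback mid-training."""
    max_grad_norm: float = 1.0
    kl_reg_weight: float = 0.0
    teacher_top_k: int = 100
    steps_per_batch: int = 4

    def __post_init__(self) -> None:
        if self.steps_per_batch < 1:
            raise ValueError(
                f"steps_per_batch must be >= 1, got {self.steps_per_batch}"
            )


TrainingConfig()  # defaults pass
try:
    TrainingConfig(steps_per_batch=0)
except ValueError as exc:
    print(exc)  # steps_per_batch must be >= 1, got 0
```

With validation at the config boundary, the request layer can surface the failure as a 4xx instead of a server error.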
Summary
- Move multi-step training settings (`steps_per_batch`, `feedback_repetitions`) from eval-owned settings into `TrainingConfig`
- Pass the `training` config through `FeedbackItem` in each `/v1/feedback` request
- Return per-step training metadata (`steps_per_batch_applied`, per-step metrics) and wire eval `sub_step_count` to that metadata

Key Implementation Notes
- New `TrainingConfig` fields: `steps_per_batch`, `feedback_repetitions`
- Eval keeps its own `EvalTrainingConfig` and converts to the runtime `TrainingConfig` in `build_harness_config`
- The Tinker engine refreshes its sampling client between steps via `save_weights_and_get_sampling_client_async`

Validation
- `uv run ruff check claas/ tests/ --fix`
- `uv run pytest tests/ -q -m "not integration"`: 109 passed, 26 skipped, 5 deselected
- `uv run ty check`: diagnostics for unavailable optional dependencies (`torch`, `tinker`, `transformers`) are expected in this environment

Summary by CodeRabbit
Release Notes
New Features
- Multi-step training per batch via the `steps_per_batch` parameter
- `feedback_repetitions` configuration option for enhanced training control
- `steps_per_batch_applied` tracks actual steps executed per batch

Documentation
Refactor