Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/workflows/config/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@
{
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
},
{
"path": "detect_secrets.filters.common.is_baseline_file",
"filename": ".github/workflows/config/.secrets.baseline"
},
{
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
"min_level": 2
Expand Down Expand Up @@ -139,10 +135,10 @@
"filename": "examples/automodel/pretrain/cicd/wan21_cicd_nightly_video.yaml",
"hashed_secret": "c70f071570ba65f9c4079d6051e955ff4f802eea",
"is_verified": false,
"line_number": 67,
"line_number": 72,
"is_secret": false
}
]
},
"generated_at": "2026-01-30T18:50:34Z"
"generated_at": "2026-02-12T07:45:24Z"
}
115 changes: 98 additions & 17 deletions dfm/src/automodel/recipes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
from nemo_automodel.components.loggers.log_utils import setup_logging
from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages
from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler
from nemo_automodel.components.training.rng import StatefulRNG
from nemo_automodel.components.training.step_scheduler import StepScheduler
from nemo_automodel.recipes.base_recipe import BaseRecipe
Expand Down Expand Up @@ -195,20 +196,93 @@ def build_model_and_optimizer(


def build_lr_scheduler(
    cfg,
    optimizer: torch.optim.Optimizer,
    total_steps: int,
) -> Optional[OptimizerParamScheduler]:
    """Build the learning-rate / weight-decay scheduler for a training run.

    Defaults are derived from the optimizer's first param group and from
    ``total_steps``; any key present in ``cfg`` overrides the matching default.

    Args:
        cfg: Scheduler configuration from YAML (a mapping, or an object exposing
            ``to_dict()``). If None, no scheduler is created and the optimizer's
            LR stays constant. Recognized keys:
            - lr_decay_style: e.g. constant, linear, cosine, inverse-square-root, WSD
              (default: cosine)
            - lr_warmup_steps: number of warmup steps, or a float in (0, 1) that is
              interpreted as a fraction of ``total_steps``
              (default: min(1000, total_steps // 10))
            - lr_decay_steps: decay horizon (default: total_steps)
            - max_lr: peak LR (default: the optimizer's current LR)
            - min_lr: LR floor after decay (default: 1% of the optimizer's LR)
            - init_lr: LR at the start of warmup (default: 10% of the optimizer's
              LR; forced to max_lr when the effective warmup is 0 steps)
            - start_wd / end_wd / wd_incr_steps / wd_incr_style: weight-decay
              schedule (default: constant at the optimizer's weight decay)
            - wsd_decay_steps / lr_wsd_decay_style: WSD-specific settings; defaults
              (total_steps // 10, cosine) are only filled in when the user selects
              lr_decay_style == "WSD"
        optimizer: The optimizer to be scheduled.
        total_steps: Total number of optimizer steps for the training run. Values
            below 1 are treated as 1.

    Returns:
        OptimizerParamScheduler instance, or None if cfg is None.
    """
    if cfg is None:
        return None

    user_cfg = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg)

    # A degenerate run (0 epochs or 0 steps per epoch) would otherwise produce
    # lr_decay_steps == 0 and wd_incr_steps == 0. The cosine scheduler this
    # function replaces clamped the horizon the same way.
    total_steps = max(1, total_steps)

    base_lr = optimizer.param_groups[0]["lr"]
    base_wd = optimizer.param_groups[0].get("weight_decay", 0.0)

    # Compute defaults from runtime values
    default_cfg: Dict[str, Any] = {
        "optimizer": optimizer,
        "lr_warmup_steps": min(1000, total_steps // 10),
        "lr_decay_steps": total_steps,
        "lr_decay_style": "cosine",
        "init_lr": base_lr * 0.1,
        "max_lr": base_lr,
        "min_lr": base_lr * 0.01,
        "start_wd": base_wd,
        "end_wd": base_wd,
        "wd_incr_steps": total_steps,
        "wd_incr_style": "constant",
    }

    # A float strictly between 0 and 1 is a fraction of the run, not a step count;
    # convert it before merging so the rest of the function only sees step counts.
    warmup = user_cfg.get("lr_warmup_steps")
    if isinstance(warmup, float) and 0 < warmup < 1:
        user_cfg["lr_warmup_steps"] = int(warmup * total_steps)

    # WSD defaults if user specifies WSD style
    if user_cfg.get("lr_decay_style") == "WSD":
        default_cfg["wsd_decay_steps"] = max(1, total_steps // 10)
        default_cfg["lr_wsd_decay_style"] = "cosine"

    # User config overrides defaults
    default_cfg.update(user_cfg)

    # Ensure warmup < decay steps. This must run before the zero-warmup check
    # below, because clamping can itself reduce the warmup to 0.
    if default_cfg["lr_warmup_steps"] >= default_cfg["lr_decay_steps"]:
        default_cfg["lr_warmup_steps"] = max(0, default_cfg["lr_decay_steps"] - 1)

    # With no warmup phase there is nothing to ramp from, so start at max_lr.
    if default_cfg["lr_warmup_steps"] == 0:
        default_cfg["init_lr"] = default_cfg["max_lr"]

    logging.info(
        f"[INFO] LR Scheduler: style={default_cfg['lr_decay_style']}, "
        f"warmup={default_cfg['lr_warmup_steps']}, total={default_cfg['lr_decay_steps']}, "
        f"max_lr={default_cfg['max_lr']}, min_lr={default_cfg['min_lr']}"
    )

    # Only the keys below are forwarded; the WSD entries are None unless the WSD
    # branch above (or the user config) supplied them.
    return OptimizerParamScheduler(
        optimizer=default_cfg["optimizer"],
        init_lr=default_cfg["init_lr"],
        max_lr=default_cfg["max_lr"],
        min_lr=default_cfg["min_lr"],
        lr_warmup_steps=default_cfg["lr_warmup_steps"],
        lr_decay_steps=default_cfg["lr_decay_steps"],
        lr_decay_style=default_cfg["lr_decay_style"],
        start_wd=default_cfg["start_wd"],
        end_wd=default_cfg["end_wd"],
        wd_incr_steps=default_cfg["wd_incr_steps"],
        wd_incr_style=default_cfg["wd_incr_style"],
        wsd_decay_steps=default_cfg.get("wsd_decay_steps"),
        lr_wsd_decay_style=default_cfg.get("lr_wsd_decay_style"),
    )


Expand Down Expand Up @@ -390,11 +464,17 @@ def setup(self):
grad_acc_steps = max(1, self.global_batch_size // max(1, self.local_batch_size * self.dp_size))
self.steps_per_epoch = ceil(self.raw_steps_per_epoch / grad_acc_steps)

self.lr_scheduler = build_lr_scheduler(
# Calculate total optimizer steps for LR scheduler
total_steps = self.num_epochs * self.steps_per_epoch

# Build LR scheduler (returns None if lr_scheduler not in config)
# Wrap in list for compatibility with checkpointing (OptimizerState expects list)
lr_scheduler = build_lr_scheduler(
self.cfg.get("lr_scheduler", None),
self.optimizer,
num_epochs=self.num_epochs,
steps_per_epoch=self.steps_per_epoch,
total_steps,
)
self.lr_scheduler = [lr_scheduler] if lr_scheduler is not None else None

self.global_step = 0
self.start_epoch = 0
Expand Down Expand Up @@ -490,7 +570,8 @@ def run_train_validation_loop(self):
grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm

self.optimizer.step()
self.lr_scheduler.step()
if self.lr_scheduler is not None:
self.lr_scheduler[0].step(1)

group_loss_mean = float(sum(micro_losses) / len(micro_losses))
epoch_loss += group_loss_mean
Expand Down
Loading