From 9f16aff18870b7283ea1cbcfc2265478078c8d2c Mon Sep 17 00:00:00 2001
From: emergenz
Date: Wed, 31 Dec 2025 18:35:19 +0100
Subject: [PATCH] chore: sync lora implementation with upstream

---
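Notes for reviewers: LoRA now goes through the PEFT library and is enabled by
--lora-rank > 0 (plus --target-modules), or by --lora-adapter-path, replacing
the old --use-lora flag. A minimal sketch of the resulting flow, assuming an
`args` namespace already validated by miles/utils/arguments.py:

    from miles.backends.fsdp_utils.lora_utils import apply_lora_to_model, is_lora_model

    # rank > 0 creates a fresh adapter; an adapter path resumes a saved one
    if args.lora_rank > 0 or args.lora_adapter_path:
        model = apply_lora_to_model(model, args)
        assert is_lora_model(model)

LoRA checkpoints are written with DCP to an "adapter" directory instead of
"model"; non-LoRA runs are unchanged.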
 miles/backends/fsdp_utils/actor.py      |  30 ++---
 miles/backends/fsdp_utils/checkpoint.py | 118 ++++++++-----------
 miles/backends/fsdp_utils/lora_utils.py | 113 ++++++++++++++++++
 miles/models/peft/__init__.py           |   3 -
 miles/models/peft/arguments.py          |  38 -------
 miles/models/peft/lora.py               | 145 ------------------------
 miles/utils/arguments.py                |  70 +++++++++++-
 requirements.txt                        |   1 +
 scripts/run-sft-torchrun.sh             |   4 +-
 tests/models/peft/test_lora.py          |  64 -----------
 train_sft.py                            | 145 ++++++++----------------
 11 files changed, 288 insertions(+), 443 deletions(-)
 create mode 100644 miles/backends/fsdp_utils/lora_utils.py
 delete mode 100644 miles/models/peft/__init__.py
 delete mode 100644 miles/models/peft/arguments.py
 delete mode 100644 miles/models/peft/lora.py
 delete mode 100644 tests/models/peft/test_lora.py

diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py
index 729d9a428..091aa6868 100644
--- a/miles/backends/fsdp_utils/actor.py
+++ b/miles/backends/fsdp_utils/actor.py
@@ -26,9 +26,9 @@
 from ...utils import tracking_utils
 from ...utils.profile_utils import TrainProfiler
-from ...models.peft import LoRAConfig, apply_lora
 from . import checkpoint
 from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
+from .lora_utils import apply_lora_to_model, is_lora_model
 from .lr_scheduler import get_lr_scheduler
 from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor
@@ -95,15 +95,8 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
             attn_implementation=self.args.attn_implementation,
         )

-        if args.use_lora:
-            lora_config = LoRAConfig(
-                lora_rank=args.lora_rank,
-                lora_alpha=args.lora_alpha,
-                lora_dropout=args.lora_dropout,
-                target_modules=args.lora_target_modules,
-            )
-            model = apply_lora(model, lora_config)
-            logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}")
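+        # LoRA is enabled by a positive rank (fresh adapter) or an adapter path (resume).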
+        if args.lora_rank > 0 or args.lora_adapter_path:
+            model = apply_lora_to_model(model, args)

         model.train()
@@ -118,20 +111,21 @@
         self.model = model

         if args.gradient_checkpointing:
-            # Use non-reentrant mode for gradient checkpointing
-            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+            # Use non-reentrant mode for gradient checkpointing (required for PEFT/LoRA)
+            gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

         if args.optimizer == "adam":
             trainable_params = [p for p in self.model.parameters() if p.requires_grad]
-
-            if args.use_lora:
+
+            if is_lora_model(self.model):
                 total_params = sum(p.numel() for p in self.model.parameters())
                 trainable_count = sum(p.numel() for p in trainable_params)
                 logger.info(
                     f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params "
                     f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)"
                 )
-
+
             self.optimizer = torch.optim.AdamW(
                 trainable_params,
                 lr=args.lr,
@@ -344,11 +338,7 @@ def save_model(self, iteration: int) -> None:
         if self.args.debug_rollout_only or self.args.save is None:
             return

-        keys_filter = None
-        if self.args.use_lora:
-            keys_filter = lambda k: "lora_" in k
-
-        checkpoint.save(self, iteration, keys_filter=keys_filter)
+        checkpoint.save(self, iteration)

     def _compute_log_prob(
         self,
diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py
index 3846bd98c..5287e047e 100644
--- a/miles/backends/fsdp_utils/checkpoint.py
+++ b/miles/backends/fsdp_utils/checkpoint.py
@@ -6,13 +6,13 @@
 from pathlib import Path
 from typing import Any

-import safetensors.torch
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from miles.models.peft import LoRAConfig
+
+from .lora_utils import is_lora_model

 logger = logging.getLogger(__name__)
@@ -20,53 +20,45 @@
 class ModelState(Stateful):
     """Wrapper for model state only."""

-    def __init__(self, model, keys_filter=None):
+    def __init__(self, model, lora_only: bool = False):
         self.model = model
-        self.keys_filter = keys_filter
+        self.lora_only = lora_only
+        self._key = "adapter" if lora_only else "model"

     def state_dict(self):
         model_state_dict, _ = get_state_dict(self.model, optimizers=[])
-        if self.keys_filter:
-            model_state_dict = {k: v for k, v in model_state_dict.items() if self.keys_filter(k)}
-        return {"model": model_state_dict}
+        if self.lora_only:
+            model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
+        return {self._key: model_state_dict}

     def load_state_dict(self, state_dict):
-        options = None
-        if self.keys_filter:
-            # For filtered loading (e.g., LoRA), use strict=False to allow partial loading
-            options = StateDictOptions(strict=False)
-        set_state_dict(
-            self.model, optimizers=[],
-            model_state_dict=state_dict["model"],
-            optim_state_dict=None,
-            options=options
-        )
+        data = state_dict[self._key]
+
+        if self.lora_only:
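+            # Adapter checkpoints hold only "lora_" keys; merge them over the
+            # current full state dict so set_state_dict sees every expected key.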
+            full_state_dict, _ = get_state_dict(self.model, optimizers=[])
+            full_state_dict.update(data)
+            set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
+        else:
+            set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)


 class OptimizerState(Stateful):
     """Wrapper for optimizer state only."""

-    def __init__(self, model, optimizer, keys_filter=None):
+    def __init__(self, model, optimizer):
         self.model = model
         self.optimizer = optimizer
-        self.keys_filter = keys_filter

     def state_dict(self):
         _, optimizer_state_dict = get_state_dict(self.model, optimizers=self.optimizer)
-        if self.keys_filter:
-            optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if self.keys_filter(k)}
         return {"optim": optimizer_state_dict}

     def load_state_dict(self, state_dict):
-        options = None
-        if self.keys_filter:
-            # For filtered loading (e.g., LoRA), use strict=False to allow partial loading
-            options = StateDictOptions(strict=False)
         set_state_dict(
-            self.model, optimizers=self.optimizer,
-            model_state_dict=None,
+            self.model,
+            optimizers=self.optimizer,
+            model_state_dict=None,
             optim_state_dict=state_dict["optim"],
-            options=options
         )
@@ -127,31 +119,28 @@ def load(actor: Any) -> dict[str, Any] | None:
     model_dir = checkpoint_dir / "model"
     optimizer_dir = checkpoint_dir / "optimizer"
     lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
+    lora_dir = checkpoint_dir / "adapter"

-    if not model_dir.exists():
-        logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
-        return None
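+    # A LoRA run checkpoints only the adapter; prefer it when present, falling
+    # back to the full model directory otherwise.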
+ """ + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + """Check if a module is a PEFT LoRA model. + + Args: + module: The module to check. + + Returns: + True if the module has PEFT LoRA applied. + """ + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return hasattr(unwrapped, "peft_config") + + +def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: + """Save LoRA adapter to disk in HuggingFace PEFT format. + + Args: + module: The PEFT model to save. + save_dir: Directory to save the adapter to. + + Returns: + The save directory path. + """ + # Gather full state dict (all-gather LoRA weights) + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(module, options=options) + + lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + + if dist.get_rank() == 0: + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + module.save_pretrained(str(save_path), state_dict=lora_state_dict) + + # Sync filesystem + os.sync() + + logger.info(f"Saved LoRA adapter to {save_path}") + return save_dir + + +def delete_lora_from_disk(save_dir: str) -> None: + """Delete LoRA adapter files from disk. + + Args: + save_dir: Directory containing the adapter to delete. 
+ """ + save_path = Path(save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/models/peft/__init__.py b/miles/models/peft/__init__.py deleted file mode 100644 index 25e0269ea..000000000 --- a/miles/models/peft/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .lora import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict -from .arguments import add_lora_arguments - diff --git a/miles/models/peft/arguments.py b/miles/models/peft/arguments.py deleted file mode 100644 index 3d12b6a72..000000000 --- a/miles/models/peft/arguments.py +++ /dev/null @@ -1,38 +0,0 @@ -import argparse - -def add_lora_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add LoRA arguments to the parser.""" - group = parser.add_argument_group(title="LoRA") - - group.add_argument( - "--use-lora", - action="store_true", - help="Whether to use LoRA for training.", - ) - group.add_argument( - "--lora-rank", - type=int, - default=8, - help="LoRA rank.", - ) - group.add_argument( - "--lora-alpha", - type=int, - default=16, - help="LoRA alpha.", - ) - group.add_argument( - "--lora-dropout", - type=float, - default=0.0, - help="LoRA dropout.", - ) - group.add_argument( - "--lora-target-modules", - type=str, - nargs="+", - default=["q_proj", "v_proj"], - help="List of module names to apply LoRA to.", - ) - - return parser diff --git a/miles/models/peft/lora.py b/miles/models/peft/lora.py deleted file mode 100644 index 196645a2a..000000000 --- a/miles/models/peft/lora.py +++ /dev/null @@ -1,145 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import List, Dict, Any - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -@dataclass -class LoRAConfig: - """Configuration for LoRA.""" - lora_rank: int = 8 - lora_alpha: int = 16 - lora_dropout: float = 0.0 - target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - bias: str = "none" # "none", "all", or "lora_only" - currently only "none" supported for simplicity - - def to_hf_peft_config(self) -> Dict[str, Any]: - """Convert to Hugging Face PEFT config format.""" - return { - "peft_type": "LORA", - "task_type": "CAUSAL_LM", - "inference_mode": False, - "r": self.lora_rank, - "lora_alpha": self.lora_alpha, - "lora_dropout": self.lora_dropout, - "target_modules": self.target_modules, - "bias": self.bias, - } - - -class LoRALinear(nn.Module): - """ - LoRA linear layer that wraps a base linear layer. - - Args: - base_layer: The existing Linear layer to wrap. - rank: LoRA rank (r). - alpha: LoRA alpha (scaling factor). - dropout: Dropout probability for LoRA input. 
- """ - def __init__( - self, - base_layer: nn.Linear, - rank: int = 8, - alpha: int = 16, - dropout: float = 0.0 - ): - super().__init__() - self.base_layer = base_layer - self.rank = rank - self.alpha = alpha - self.scaling = alpha / rank - - self.lora_A = nn.Parameter(torch.zeros(rank, base_layer.in_features)) - self.lora_B = nn.Parameter(torch.zeros(base_layer.out_features, rank)) - - self.dropout = nn.Dropout(p=dropout) - - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - result = self.base_layer(x) - - lora_out = self.dropout(x) - lora_out = F.linear(lora_out, self.lora_A) - lora_out = F.linear(lora_out, self.lora_B) - - return result + lora_out * self.scaling - - def __repr__(self): - return ( - f"LoRALinear(in_features={self.base_layer.in_features}, " - f"out_features={self.base_layer.out_features}, " - f"rank={self.rank}, alpha={self.alpha})" - ) - - -def apply_lora(model: nn.Module, config: LoRAConfig) -> nn.Module: - """ - Apply LoRA to the model by replacing target linear layers with LoRALinear. - - Args: - model: The model to modify. - config: LoRA configuration. - - Returns: - The modified model. - """ - assert config.bias == "none", "Only bias='none' is currently supported" - target_modules = set(config.target_modules) - - # We need to collect replacements first to avoid modifying the dict while iterating - modules_to_replace = [] - - for name, module in model.named_modules(): - # Check if this module name ends with any of the target modules - # e.g. "model.layers.0.self_attn.q_proj" ends with "q_proj" - if any(name.endswith(target) for target in target_modules): - if isinstance(module, nn.Linear): - modules_to_replace.append((name, module)) - - if not modules_to_replace: - raise ValueError(f"No modules found matching {target_modules}") - - for name, module in modules_to_replace: - if '.' 
+    """
+    # Gather full state dict (all-gather LoRA weights)
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+    full_state_dict = get_model_state_dict(module, options=options)
+
+    lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name}
+
+    if dist.get_rank() == 0:
+        save_path = Path(save_dir)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        module.save_pretrained(str(save_path), state_dict=lora_state_dict)
+
+        # Sync filesystem
+        os.sync()
+
+        logger.info(f"Saved LoRA adapter to {save_path}")
+    return save_dir
+
+
+def delete_lora_from_disk(save_dir: str) -> None:
+    """Delete LoRA adapter files from disk.
+
+    Args:
+        save_dir: Directory containing the adapter to delete.
+    """
+    save_path = Path(save_dir)
+    if save_path.exists():
+        shutil.rmtree(save_path)
+        logger.info(f"Deleted LoRA adapter from {save_path}")
diff --git a/miles/models/peft/__init__.py b/miles/models/peft/__init__.py
deleted file mode 100644
index 25e0269ea..000000000
--- a/miles/models/peft/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .lora import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict
-from .arguments import add_lora_arguments
-
diff --git a/miles/models/peft/arguments.py b/miles/models/peft/arguments.py
deleted file mode 100644
index 3d12b6a72..000000000
--- a/miles/models/peft/arguments.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import argparse
-
-def add_lora_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-    """Add LoRA arguments to the parser."""
-    group = parser.add_argument_group(title="LoRA")
-
-    group.add_argument(
-        "--use-lora",
-        action="store_true",
-        help="Whether to use LoRA for training.",
-    )
-    group.add_argument(
-        "--lora-rank",
-        type=int,
-        default=8,
-        help="LoRA rank.",
-    )
-    group.add_argument(
-        "--lora-alpha",
-        type=int,
-        default=16,
-        help="LoRA alpha.",
-    )
-    group.add_argument(
-        "--lora-dropout",
-        type=float,
-        default=0.0,
-        help="LoRA dropout.",
-    )
-    group.add_argument(
-        "--lora-target-modules",
-        type=str,
-        nargs="+",
-        default=["q_proj", "v_proj"],
-        help="List of module names to apply LoRA to.",
-    )
-
-    return parser
diff --git a/miles/models/peft/lora.py b/miles/models/peft/lora.py
deleted file mode 100644
index 196645a2a..000000000
--- a/miles/models/peft/lora.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import math
-from dataclasses import dataclass, field
-from typing import List, Dict, Any
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-@dataclass
-class LoRAConfig:
-    """Configuration for LoRA."""
-    lora_rank: int = 8
-    lora_alpha: int = 16
-    lora_dropout: float = 0.0
-    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
-    bias: str = "none"  # "none", "all", or "lora_only" - currently only "none" supported for simplicity
-
-    def to_hf_peft_config(self) -> Dict[str, Any]:
-        """Convert to Hugging Face PEFT config format."""
-        return {
-            "peft_type": "LORA",
-            "task_type": "CAUSAL_LM",
-            "inference_mode": False,
-            "r": self.lora_rank,
-            "lora_alpha": self.lora_alpha,
-            "lora_dropout": self.lora_dropout,
-            "target_modules": self.target_modules,
-            "bias": self.bias,
-        }
-
-
-class LoRALinear(nn.Module):
-    """
-    LoRA linear layer that wraps a base linear layer.
-
-    Args:
-        base_layer: The existing Linear layer to wrap.
-        rank: LoRA rank (r).
-        alpha: LoRA alpha (scaling factor).
-        dropout: Dropout probability for LoRA input.
-    """
-    def __init__(
-        self,
-        base_layer: nn.Linear,
-        rank: int = 8,
-        alpha: int = 16,
-        dropout: float = 0.0
-    ):
-        super().__init__()
-        self.base_layer = base_layer
-        self.rank = rank
-        self.alpha = alpha
-        self.scaling = alpha / rank
-
-        self.lora_A = nn.Parameter(torch.zeros(rank, base_layer.in_features))
-        self.lora_B = nn.Parameter(torch.zeros(base_layer.out_features, rank))
-
-        self.dropout = nn.Dropout(p=dropout)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_B)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        result = self.base_layer(x)
-
-        lora_out = self.dropout(x)
-        lora_out = F.linear(lora_out, self.lora_A)
-        lora_out = F.linear(lora_out, self.lora_B)
-
-        return result + lora_out * self.scaling
-
-    def __repr__(self):
-        return (
-            f"LoRALinear(in_features={self.base_layer.in_features}, "
-            f"out_features={self.base_layer.out_features}, "
-            f"rank={self.rank}, alpha={self.alpha})"
-        )
-
-
-def apply_lora(model: nn.Module, config: LoRAConfig) -> nn.Module:
-    """
-    Apply LoRA to the model by replacing target linear layers with LoRALinear.
-
-    Args:
-        model: The model to modify.
-        config: LoRA configuration.
-
-    Returns:
-        The modified model.
-    """
-    assert config.bias == "none", "Only bias='none' is currently supported"
-    target_modules = set(config.target_modules)
-
-    # We need to collect replacements first to avoid modifying the dict while iterating
-    modules_to_replace = []
-
-    for name, module in model.named_modules():
-        # Check if this module name ends with any of the target modules
-        # e.g. "model.layers.0.self_attn.q_proj" ends with "q_proj"
-        if any(name.endswith(target) for target in target_modules):
-            if isinstance(module, nn.Linear):
-                modules_to_replace.append((name, module))
-
-    if not modules_to_replace:
-        raise ValueError(f"No modules found matching {target_modules}")
-
-    for name, module in modules_to_replace:
-        if '.' in name:
-            parent_name, child_name = name.rsplit('.', 1)
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            child_name = name
-            parent = model
-
-        lora_layer = LoRALinear(
-            base_layer=module,
-            rank=config.lora_rank,
-            alpha=config.lora_alpha,
-            dropout=config.lora_dropout
-        )
-
-        setattr(parent, child_name, lora_layer)
-
-    # Freeze all non-LoRA parameters
-    for n, p in model.named_parameters():
-        if "lora_" not in n:
-            p.requires_grad = False
-
-    return model
-
-
-def get_lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
-    """Return state dict with only LoRA parameters."""
-    return {k: v for k, v in model.state_dict().items() if "lora_" in k}
-
-
-def load_lora_state_dict(model: nn.Module, state_dict: Dict[str, torch.Tensor], strict: bool = False):
-    """Load LoRA parameters into the model."""
-    # We only load keys that exist in the state_dict and match LoRA params
-    model.load_state_dict(state_dict, strict=strict)
diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py
index 65066e3e6..c65c8ef1f 100644
--- a/miles/utils/arguments.py
+++ b/miles/utils/arguments.py
@@ -10,7 +10,6 @@
 from miles.backends.sglang_utils.arguments import add_sglang_arguments
 from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args
-from miles.models.peft import add_lora_arguments
 from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list
 from miles.utils.logging_utils import configure_logger
@@ -1256,6 +1255,47 @@ def add_sglang_tp_size():
     if add_custom_arguments is not None:
         parser = add_custom_arguments(parser)

+    def add_lora_arguments(parser):
+        """Add LoRA arguments matching PR 326/377 style."""
+        group = parser.add_argument_group(title="LoRA")
+
+        group.add_argument(
+            "--lora-rank",
+            type=int,
+            default=0,
+            help="LoRA rank. Set to 0 to disable LoRA (default: 0).",
+        )
+        group.add_argument(
+            "--lora-alpha",
+            type=int,
+            default=16,
+            help="LoRA alpha parameter (default: 16).",
+        )
+        group.add_argument(
+            "--target-modules",
+            type=str,
+            default=None,
+            help=(
+                "Target modules for LoRA adaptation. "
+                "Can be 'all-linear', a single module name, or comma-separated module names. "
+                "Example: 'q_proj,k_proj,v_proj' (default: None)"
+            ),
+        )
+        group.add_argument(
+            "--exclude-modules",
+            type=str,
+            default=None,
+            help="Comma-separated list of modules to exclude from LoRA adaptation (default: None).",
+        )
+        group.add_argument(
+            "--lora-adapter-path",
+            type=str,
+            default=None,
+            help="Path to load pre-trained LoRA adapter weights (default: None).",
+        )
+
+        return parser
+
     parser = add_cluster_arguments(parser)
     parser = add_train_arguments(parser)
     parser = add_lora_arguments(parser)
@@ -1403,6 +1443,34 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]:
 def miles_validate_args(args):
     args.eval_datasets = _resolve_eval_datasets(args)

+    # LoRA validation
+    if args.lora_rank > 0:
+        if args.train_backend == "megatron":
+            raise NotImplementedError(
+                "LoRA is not yet implemented for Megatron backend. "
+                "Please use FSDP backend (--train-backend fsdp) or disable LoRA (--lora-rank 0)."
+            )
+        if args.target_modules is None:
+            raise ValueError("'--target-modules' is required when LoRA is enabled (--lora-rank > 0).")
+
+        # Process target_modules into a list
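+        # "all-linear" expands to the projection names used by common decoder
+        # blocks (attention q/k/v/o plus MLP gate/up/down).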
+        if args.target_modules == "all-linear":
+            modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+        elif "," in args.target_modules:
+            modules = [m.strip() for m in args.target_modules.split(",")]
+        else:
+            modules = [args.target_modules]
+
+        if args.exclude_modules:
+            exclude_set = (
+                set(m.strip() for m in args.exclude_modules.split(","))
+                if "," in args.exclude_modules
+                else {args.exclude_modules}
+            )
+            modules = [m for m in modules if m not in exclude_set]
+
+        args.target_modules = modules
+
     if args.kl_coef != 0 or args.use_kl_loss:
         if not os.path.exists(args.ref_load):
             raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.")
diff --git a/requirements.txt b/requirements.txt
index 2c20195fc..3840f294d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ httpx[http2]
 mcp[cli]
 memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml
 omegaconf
+peft
 pillow
 pylatexenc
 pyyaml
diff --git a/scripts/run-sft-torchrun.sh b/scripts/run-sft-torchrun.sh
index 33d99b3f9..70aa7888d 100644
--- a/scripts/run-sft-torchrun.sh
+++ b/scripts/run-sft-torchrun.sh
@@ -54,11 +54,9 @@ SFT_ARGS=(
 )

 LORA_ARGS=(
-    --use-lora
     --lora-rank 8
     --lora-alpha 16
-    --lora-dropout 0.0
-    --lora-target-modules q_proj v_proj
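+    # comma-separated; split into a list during argument validation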
+    --target-modules q_proj,v_proj
 )

 OPTIMIZER_ARGS=(
diff --git a/tests/models/peft/test_lora.py b/tests/models/peft/test_lora.py
deleted file mode 100644
index 81ccc2e26..000000000
--- a/tests/models/peft/test_lora.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import unittest
-import torch
-import torch.nn as nn
-from miles.models.peft import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict
-
-class TestLoRA(unittest.TestCase):
-    def setUp(self):
-        self.input_dim = 10
-        self.output_dim = 20
-        self.base_layer = nn.Linear(self.input_dim, self.output_dim)
-        self.config = LoRAConfig(lora_rank=4, lora_alpha=8, lora_dropout=0.0)
-
-    def test_lora_linear_forward(self):
-        lora_layer = LoRALinear(self.base_layer, rank=4, alpha=8, dropout=0.0)
-        x = torch.randn(5, self.input_dim)
-
-        # Initial forward should match base layer (since B is zero)
-        out_lora = lora_layer(x)
-        out_base = self.base_layer(x)
-        torch.testing.assert_close(out_lora, out_base)
-
-        # Modify LoRA weights
-        lora_layer.lora_B.data.fill_(1.0)
-        out_lora_mod = lora_layer(x)
-        self.assertFalse(torch.allclose(out_lora_mod, out_base))
-
-    def test_apply_lora(self):
-        model = nn.Sequential(
-            nn.Linear(10, 10),
-            nn.Linear(10, 10)
-        )
-        # Name modules to match default target "q_proj", "v_proj" won't work here.
-        # Let's use custom config
-        config = LoRAConfig(target_modules=["0"], lora_rank=4)
-
-        model = apply_lora(model, config)
-
-        self.assertIsInstance(model[0], LoRALinear)
-        self.assertIsInstance(model[1], nn.Linear)
-
-        # Check gradients
-        self.assertTrue(model[0].lora_A.requires_grad)
-        self.assertFalse(model[0].base_layer.weight.requires_grad)
-        self.assertFalse(model[1].weight.requires_grad)  # Should be frozen by apply_lora
-
-    def test_state_dict(self):
-        model = nn.Sequential(
-            nn.Linear(10, 10)
-        )
-        config = LoRAConfig(target_modules=["0"], lora_rank=4)
-        model = apply_lora(model, config)
-
-        state_dict = get_lora_state_dict(model)
-        self.assertEqual(len(state_dict), 2)  # A and B
-        self.assertTrue(all("lora_" in k for k in state_dict.keys()))
-
-        # Test loading
-        new_state = {k: torch.ones_like(v) for k, v in state_dict.items()}
-        load_lora_state_dict(model, new_state)
-
-        self.assertTrue(torch.allclose(model[0].lora_A, torch.ones_like(model[0].lora_A)))
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/train_sft.py b/train_sft.py
index 5d18a607f..d23f21830 100644
--- a/train_sft.py
+++ b/train_sft.py
@@ -25,24 +25,15 @@
 import torch
 import torch.distributed as dist
+from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params
 from torch.distributed.device_mesh import init_device_mesh
 from tqdm import tqdm
 from transformers import AutoConfig

-from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params
-
-from miles.models.peft import LoRAConfig, apply_lora
 from miles.backends.fsdp_utils import checkpoint
-from miles.backends.fsdp_utils.actor import (
-    apply_fsdp2,
-    get_logprob_and_entropy_with_cp,
-    sum_of_sample_mean,
-)
-from miles.backends.fsdp_utils.data_packing import (
-    pack_sequences,
-    pad_packed_sequence_with_cp,
-    unpack_sequences,
-)
+from miles.backends.fsdp_utils.actor import apply_fsdp2, get_logprob_and_entropy_with_cp, sum_of_sample_mean
+from miles.backends.fsdp_utils.data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
+from miles.backends.fsdp_utils.lora_utils import apply_lora_to_model, is_lora_model
 from miles.backends.fsdp_utils.lr_scheduler import get_lr_scheduler
 from miles.rollout.data_source import RolloutDataSource
 from miles.utils import tracking_utils
@@ -158,9 +149,7 @@ def _enable_true_on_policy_optimizations(self):
         if self.args.true_on_policy_mode:
             from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

-            from miles.backends.fsdp_utils.models.qwen3_moe import (
-                apply_true_on_policy_patch_for_qwen3_moe,
-            )
+            from miles.backends.fsdp_utils.models.qwen3_moe import apply_true_on_policy_patch_for_qwen3_moe

             logger.info("SFTTrainer: enabling batch_invariant_mode for true-on-policy")
             enable_batch_invariant_mode(
@@ -178,17 +167,11 @@ def _load_tokenizer_and_config(self):
         """Load tokenizer and model config sequentially to avoid race conditions."""
         for i in range(dist.get_world_size()):
             if i == dist.get_rank():
-                self.hf_config = AutoConfig.from_pretrained(
-                    self.args.hf_checkpoint, trust_remote_code=True
-                )
-                self.tokenizer = load_tokenizer(
-                    self.args.hf_checkpoint, trust_remote_code=True
-                )
+                self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
+                self.tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True)
                 self.processor = None
                 if self.args.multimodal_keys:
-                    self.processor = load_processor(
-                        self.args.hf_checkpoint, trust_remote_code=True
-                    )
+                    self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
             dist.barrier(group=get_gloo_group())

         # Initialize loss mask generator for SFT
@@ -237,10 +220,7 @@ def cpu_init_weights():

     def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offload):
         """Load full state dict into FSDP2 model with broadcast from rank 0."""
-        from torch.distributed.checkpoint.state_dict import (
-            StateDictOptions,
-            set_model_state_dict,
-        )
+        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

         if dist.get_rank() == 0:
             model = model.to(device=torch.cuda.current_device(), non_blocking=True)
@@ -248,9 +228,7 @@ def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offloa
             model = model.to_empty(device=torch.cuda.current_device())

         is_cpu_offload = cpu_offload is not None
-        options = StateDictOptions(
-            full_state_dict=True, cpu_offload=is_cpu_offload, broadcast_from_rank0=True
-        )
+        options = StateDictOptions(full_state_dict=True, cpu_offload=is_cpu_offload, broadcast_from_rank0=True)
         set_model_state_dict(model, full_state, options=options)
@@ -288,22 +266,13 @@ def _init_model(self):
             attn_implementation=self.args.attn_implementation,
         )

-        if self.args.use_lora:
-            lora_config = LoRAConfig(
-                lora_rank=self.args.lora_rank,
-                lora_alpha=self.args.lora_alpha,
-                lora_dropout=self.args.lora_dropout,
-                target_modules=self.args.lora_target_modules,
-            )
-            model = apply_lora(model, lora_config)
-            logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}")
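+        # Mirrors the actor-side check: positive rank or a saved adapter path.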
+        if self.args.lora_rank > 0 or self.args.lora_adapter_path:
+            model = apply_lora_to_model(model, self.args)

         model.train()

         full_state = model.state_dict()
-        model = apply_fsdp2(
-            model, mesh=self.dp_mesh, cpu_offload=self.fsdp_cpu_offload, args=self.args
-        )
+        model = apply_fsdp2(model, mesh=self.dp_mesh, cpu_offload=self.fsdp_cpu_offload, args=self.args)

         model = self._fsdp2_load_full_state_dict(
             model,
@@ -315,23 +284,24 @@
         self.model = model

         if self.args.gradient_checkpointing:
-            # Use non-reentrant mode for gradient checkpointing
-            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+            # Use non-reentrant mode for gradient checkpointing (required for PEFT/LoRA)
+            gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

         logger.info(f"[Rank {dist.get_rank()}] Model initialized with FSDP")

     def _init_optimizer(self):
         """Initialize optimizer and learning rate scheduler."""
         trainable_params = [p for p in self.model.parameters() if p.requires_grad]
-
-        if self.args.use_lora:
+
+        if is_lora_model(self.model):
             total_params = sum(p.numel() for p in self.model.parameters())
             trainable_count = sum(p.numel() for p in trainable_params)
             logger.info(
                 f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params "
                 f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)"
             )
-
+
         if self.args.optimizer == "adam":
             self.optimizer = torch.optim.AdamW(
                 trainable_params,
@@ -351,7 +321,7 @@ def _load_checkpoint(self):
         """Load checkpoint if available."""
         checkpoint_payload = checkpoint.load(self)
         checkpoint.finalize_load(self, checkpoint_payload)
-
+
         if self.args.rollout_global_dataset and self.args.start_rollout_id > 0:
             self.data_source.load(self.args.start_rollout_id - 1)
@@ -372,10 +342,7 @@ def generate_sft_rollout(self, rollout_id: int, data_source: RolloutDataSource)
             result.append(sample)

             if i == 0 and rollout_id == 0 and dist.get_rank() == 0:
-                logger.info(
-                    f"SFT rollout sample: tokens_len={len(token_ids)}, "
-                    f"response_length={response_length}"
-                )
+                logger.info(f"SFT rollout sample: tokens_len={len(token_ids)}, " f"response_length={response_length}")

         return result
@@ -443,24 +410,19 @@ def _packed_data(self, rollout_data: dict) -> tuple[list[dict], list[int]]:
                     max_tokens,
                 )
             )
-            num_microbatches = torch.tensor(
-                mbs_size_list, dtype=torch.int, device=torch.cuda.current_device()
-            )
+            num_microbatches = torch.tensor(mbs_size_list, dtype=torch.int, device=torch.cuda.current_device())
             dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=self.dp_group)
             num_microbatches = num_microbatches.tolist()
         else:
-            num_microbatches = [
-                self.args.global_batch_size // (self.args.micro_batch_size * self.dp_size)
-            ] * (len(tokens) // local_batch_size)
+            num_microbatches = [self.args.global_batch_size // (self.args.micro_batch_size * self.dp_size)] * (
+                len(tokens) // local_batch_size
+            )

         start = 0
         for mbs_size in num_microbatches:
             end = start + local_batch_size
             # Create dummy advantages/returns for SFT (not used but required by pack_sequences)
-            dummy_advantages = [
-                torch.zeros(rollout_data["response_lengths"][i])
-                for i in range(start, end)
-            ]
+            dummy_advantages = [torch.zeros(rollout_data["response_lengths"][i]) for i in range(start, end)]
             packed_batches.extend(
                 pack_sequences(
                     rollout_data["tokens"][start:end],
@@ -491,12 +453,8 @@ def _get_model_inputs_args(self, packed_sequence: dict) -> dict:
             cu_seqlens = packed_sequence["cu_seqlens"]
             update_ring_flash_attn_params(cu_seqlens, self.cp_group)

-            input_ids = torch.chunk(
-                packed_sequence["tokens"].unsqueeze(0), self.cp_size, dim=1
-            )[self.cp_rank]
-            position_ids = torch.chunk(
-                packed_sequence["position_ids"].unsqueeze(0), self.cp_size, dim=1
-            )[self.cp_rank]
+            input_ids = torch.chunk(packed_sequence["tokens"].unsqueeze(0), self.cp_size, dim=1)[self.cp_rank]
+            position_ids = torch.chunk(packed_sequence["position_ids"].unsqueeze(0), self.cp_size, dim=1)[self.cp_rank]

             model_args = {
                 "input_ids": input_ids,
@@ -511,13 +469,9 @@ def _compute_sft_loss(self, unpacked_batches: list[dict], logits: torch.Tensor):
         """Compute SFT loss (negative log likelihood)."""
-        loss_masks = [
-            batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches
-        ]
+        loss_masks = [batch["loss_masks"].to(device=logits.device) for batch in unpacked_batches]
         response_lengths = [batch["response_lengths"] for batch in unpacked_batches]
-        log_probs = torch.cat(
-            [batch["cur_log_probs"] for batch in unpacked_batches], dim=0
-        )
+        log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0)

         loss = -sum_of_sample_mean(log_probs, response_lengths, loss_masks)
         if log_probs.numel() == 0:
@@ -634,10 +588,12 @@ def calculate_val_loss(self, rollout_id: int):
             packed_batches, accum = self._packed_data(rollout_data)

             if len(accum) == 0:
-                logger.warning(f"[Rank {dist.get_rank()}] No batches to validate on rollout {rollout_id}, validation step {v_step}")
+                logger.warning(
+                    f"[Rank {dist.get_rank()}] No batches to validate on rollout {rollout_id}, validation step {v_step}"
+                )
                 return

-            for mbs_id, packed_batch in enumerate(packed_batches):
+            for _mbs_id, packed_batch in enumerate(packed_batches):
                 reported = self._val_step(packed_batch)
                 for k, v in reported.items():
                     reported_accum.setdefault(k, []).append(v)
@@ -648,12 +604,12 @@ def calculate_val_loss(self, rollout_id: int):
         dist.all_gather_object(reduced_aggregated, aggregated, group=self.dp_group)
         aggregated = {}
         for k in reported_accum.keys():
-            aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (self.args.global_batch_size * self.args.val_steps)
+            aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (
+                self.args.global_batch_size * self.args.val_steps
+            )
         reported_accum.clear()
         if dist.get_rank() == 0:
-            log_dict = {
-                f"val/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items()
-            }
+            log_dict = {f"val/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items()}
             logger.info(f"step {self.global_step}: {log_dict}")
             log_dict["val/step"] = self.global_step
             tracking_utils.log(self.args, log_dict, step_key="val/step")
@@ -681,18 +637,13 @@ def _val_step(self, packed_batch):
         _, reported = self._compute_sft_loss(unpacked_batches, logits)
         return reported

-
     def save_model(self, iteration: int):
         """Save model checkpoint."""
         if self.args.save is None:
             return
-
-        keys_filter = None
-        if self.args.use_lora:
-            keys_filter = lambda k: "lora_" in k
-
-        checkpoint.save(self, iteration, keys_filter=keys_filter)
-
+
+        checkpoint.save(self, iteration)
+
         if self.args.rollout_global_dataset:
             self.data_source.save(iteration)
@@ -703,8 +654,12 @@ def train(self):
             f"rollout_id {self.args.start_rollout_id} -> {self.args.num_rollout}"
         )
         if self.args.val_prompt_data:
-            assert self.args.val_interval > 0, f"val_interval must be greater than 0 when val_prompt_data is provided, got {self.args.val_interval}"
-            assert self.args.val_steps > 0, f"val_steps must be greater than 0 when val_prompt_data is provided, got {self.args.val_steps}"
+            assert (
+                self.args.val_interval > 0
+            ), f"val_interval must be greater than 0 when val_prompt_data is provided, got {self.args.val_interval}"
+            assert (
+                self.args.val_steps > 0
+            ), f"val_steps must be greater than 0 when val_prompt_data is provided, got {self.args.val_steps}"

         # calculate val loss at the beginning of training
         if self.args.val_prompt_data and self.args.start_rollout_id == 0:
@@ -714,9 +669,7 @@
             self.train_one_rollout(rollout_id)

             # Save checkpoint periodically
-            if should_run_periodic_action(
-                rollout_id, self.args.save_interval, self.num_rollout_per_epoch
-            ):
+            if should_run_periodic_action(rollout_id, self.args.save_interval, self.num_rollout_per_epoch):
                 self.save_model(rollout_id)

             # Calculate val loss periodically
@@ -754,5 +707,3 @@ def main():


 if __name__ == "__main__":
     main()
-
-