30 changes: 10 additions & 20 deletions miles/backends/fsdp_utils/actor.py
@@ -26,9 +26,9 @@

from ...utils import tracking_utils
from ...utils.profile_utils import TrainProfiler
from ...models.peft import LoRAConfig, apply_lora
from . import checkpoint
from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
from .lora_utils import apply_lora_to_model, is_lora_model
from .lr_scheduler import get_lr_scheduler
from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor

@@ -95,15 +95,8 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
attn_implementation=self.args.attn_implementation,
)

if args.use_lora:
lora_config = LoRAConfig(
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules,
)
model = apply_lora(model, lora_config)
logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}")
if args.lora_rank > 0 or args.lora_adapter_path:
model = apply_lora_to_model(model, args)

model.train()

@@ -118,20 +111,21 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
self.model = model

if args.gradient_checkpointing:
# Use non-reentrant mode for gradient checkpointing
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# Use non-reentrant mode for gradient checkpointing (required for PEFT/LoRA)
gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

if args.optimizer == "adam":
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
if args.use_lora:

if is_lora_model(self.model):
total_params = sum(p.numel() for p in self.model.parameters())
trainable_count = sum(p.numel() for p in trainable_params)
logger.info(
f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params "
f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)"
)

self.optimizer = torch.optim.AdamW(
trainable_params,
lr=args.lr,
@@ -344,11 +338,7 @@ def save_model(self, iteration: int) -> None:
if self.args.debug_rollout_only or self.args.save is None:
return

keys_filter = None
if self.args.use_lora:
keys_filter = lambda k: "lora_" in k

checkpoint.save(self, iteration, keys_filter=keys_filter)
checkpoint.save(self, iteration)

def _compute_log_prob(
self,
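Taken together, the actor changes reduce LoRA setup to the flow sketched below. This is a condensed illustration of the init() logic above, not the PR's exact code; the args fields (lora_rank, lora_adapter_path, gradient_checkpointing, lr) are the ones the diff references.

    # Minimal sketch of the new init-time flow for a HF transformers model.
    import torch
    from miles.backends.fsdp_utils.lora_utils import apply_lora_to_model, is_lora_model

    def setup_model(model, args):
        # LoRA is now gated on rank / adapter path instead of the old use_lora flag.
        if args.lora_rank > 0 or args.lora_adapter_path:
            model = apply_lora_to_model(model, args)
        model.train()

        if args.gradient_checkpointing:
            # PEFT models need non-reentrant checkpointing; others keep the default.
            gc_kwargs = {"use_reentrant": False} if is_lora_model(model) else {}
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

        # Only trainable (LoRA) parameters reach the optimizer.
        trainable = [p for p in model.parameters() if p.requires_grad]
        return model, torch.optim.AdamW(trainable, lr=args.lr)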
118 changes: 46 additions & 72 deletions miles/backends/fsdp_utils/checkpoint.py
@@ -6,67 +6,59 @@
from pathlib import Path
from typing import Any

import safetensors.torch
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from miles.models.peft import LoRAConfig

from .lora_utils import is_lora_model

logger = logging.getLogger(__name__)


class ModelState(Stateful):
"""Wrapper for model state only."""

def __init__(self, model, keys_filter=None):
def __init__(self, model, lora_only: bool = False):
self.model = model
self.keys_filter = keys_filter
self.lora_only = lora_only
self._key = "adapter" if lora_only else "model"

def state_dict(self):
model_state_dict, _ = get_state_dict(self.model, optimizers=[])
if self.keys_filter:
model_state_dict = {k: v for k, v in model_state_dict.items() if self.keys_filter(k)}
return {"model": model_state_dict}
if self.lora_only:
model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
return {self._key: model_state_dict}

def load_state_dict(self, state_dict):
options = None
if self.keys_filter:
# For filtered loading (e.g., LoRA), use strict=False to allow partial loading
options = StateDictOptions(strict=False)
set_state_dict(
self.model, optimizers=[],
model_state_dict=state_dict["model"],
optim_state_dict=None,
options=options
)
data = state_dict[self._key]

if self.lora_only:
full_state_dict, _ = get_state_dict(self.model, optimizers=[])
full_state_dict.update(data)
set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
else:
set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)


class OptimizerState(Stateful):
"""Wrapper for optimizer state only."""

def __init__(self, model, optimizer, keys_filter=None):
def __init__(self, model, optimizer):
self.model = model
self.optimizer = optimizer
self.keys_filter = keys_filter

def state_dict(self):
_, optimizer_state_dict = get_state_dict(self.model, optimizers=self.optimizer)
if self.keys_filter:
optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if self.keys_filter(k)}
return {"optim": optimizer_state_dict}

def load_state_dict(self, state_dict):
options = None
if self.keys_filter:
# For filtered loading (e.g., LoRA), use strict=False to allow partial loading
options = StateDictOptions(strict=False)
set_state_dict(
self.model, optimizers=self.optimizer,
model_state_dict=None,
self.model,
optimizers=self.optimizer,
model_state_dict=None,
optim_state_dict=state_dict["optim"],
options=options
)
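
Because the wrapper now emits its state under "adapter" (LoRA-only) or "model" (full), and dcp matches keys between what state_dict() produced at save time and what load_state_dict() receives at load time, the two checkpoint flavors are distinguishable by key alone. A minimal round-trip sketch with a toy module (single process; recent PyTorch versions let dcp save/load run without an initialized process group):

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.stateful import Stateful

    class ToyState(Stateful):
        def __init__(self, model, key):
            self.model, self.key = model, key

        def state_dict(self):
            return {self.key: self.model.state_dict()}

        def load_state_dict(self, state_dict):
            # dcp hands values back under the same key state_dict() produced, so a
            # checkpoint written as "adapter" cannot be read back as "model".
            self.model.load_state_dict(state_dict[self.key])

    model = torch.nn.Linear(4, 4)
    dcp.save({"model_state": ToyState(model, "adapter")}, checkpoint_id="/tmp/toy_ckpt")
    dcp.load({"model_state": ToyState(model, "adapter")}, checkpoint_id="/tmp/toy_ckpt")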


@@ -127,31 +119,28 @@ def load(actor: Any) -> dict[str, Any] | None:
model_dir = checkpoint_dir / "model"
optimizer_dir = checkpoint_dir / "optimizer"
lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
lora_dir = checkpoint_dir / "adapter"

if not model_dir.exists():
logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
return None
lora_only = lora_dir.exists() and is_lora_model(actor.model)
load_model_dir = lora_dir if lora_only else model_dir

keys_filter = None
if actor.args.use_lora:
keys_filter = lambda k: "lora_" in k
logger.info("[FSDP] LoRA mode: loading only LoRA weights from checkpoint")
if not load_model_dir.exists():
logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.")
return None

# Load model weights (always)
model_state = ModelState(actor.model, keys_filter=keys_filter)
model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}

try:
dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir))
logger.info(f"[FSDP] Loaded model from {model_dir}")
dcp.load(state_dict=state_dict, checkpoint_id=str(load_model_dir))
logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {load_model_dir}")
except Exception as e:
logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}")
logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {load_model_dir}: {e}")
return None

# Load optimizer state (optional)
load_optimizer = not getattr(actor.args, "no_load_optim", False) and hasattr(actor, "optimizer")
if load_optimizer and optimizer_dir.exists():
optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter)
optimizer_state = OptimizerState(actor.model, actor.optimizer)
optim_state_dict = {"optim_state": optimizer_state}
try:
dcp.load(state_dict=optim_state_dict, checkpoint_id=str(optimizer_dir))
@@ -216,7 +205,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None
dist.barrier()


def save(actor: Any, iteration: int, keys_filter=None) -> None:
def save(actor: Any, iteration: int) -> None:
"""Save checkpoint to disk.

Saves model weights and optimizer state to separate directories.
@@ -239,13 +228,23 @@ def save(actor: Any, iteration: int, keys_filter=None) -> None:
dist.barrier()

# Save model weights
model_state = ModelState(actor.model, keys_filter=keys_filter)
lora_only = is_lora_model(actor.model)
if lora_only:
save_dir = checkpoint_dir / "adapter"
if dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dist.barrier()
else:
save_dir = model_dir

model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}
dcp.save(state_dict, checkpoint_id=str(model_dir))
dcp.save(state_dict, checkpoint_id=str(save_dir))
logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}")

# Save optimizer state
if hasattr(actor, "optimizer") and actor.optimizer is not None:
optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter)
optimizer_state = OptimizerState(actor.model, actor.optimizer)
optim_state_dict = {"optim_state": optimizer_state}
dcp.save(optim_state_dict, checkpoint_id=str(optimizer_dir))

@@ -275,29 +274,4 @@ def save(actor: Any, iteration: int, keys_filter=None) -> None:
tracker_file.write_text(str(step_id))
logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}")

if actor.args.use_lora:
_save_hf_lora(actor, checkpoint_dir)

dist.barrier()


def _save_hf_lora(actor: Any, checkpoint_dir: Path) -> None:
"""Save LoRA adapter in Hugging Face PEFT format."""

options = dcp.state_dict.StateDictOptions(full_state_dict=True, cpu_offload=True)
full_state_dict = get_model_state_dict(actor.model, options=options)

if dist.get_rank() == 0:
lora_config = LoRAConfig(
lora_rank=actor.args.lora_rank,
lora_alpha=actor.args.lora_alpha,
lora_dropout=actor.args.lora_dropout,
target_modules=actor.args.lora_target_modules,
)
peft_config = lora_config.to_hf_peft_config()
with open(checkpoint_dir / "adapter_config.json", "w") as f:
json.dump(peft_config, f, indent=2)

lora_state_dict = {k: v for k, v in full_state_dict.items() if "lora_" in k}
safetensors.torch.save_file(lora_state_dict, checkpoint_dir / "adapter_model.safetensors")
logger.info(f"[FSDP] Saved HF LoRA adapter to {checkpoint_dir}")
113 changes: 113 additions & 0 deletions miles/backends/fsdp_utils/lora_utils.py
@@ -0,0 +1,113 @@
"""LoRA utilities for FSDP backend using PEFT library.

This module provides functions for applying, detecting, and saving LoRA adapters
in a way that's compatible with the PEFT library and HuggingFace ecosystem.
"""

import logging
import os
import shutil
from pathlib import Path

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

try:
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
except ImportError as err:
raise ImportError("peft library required for LoRA. Install with: pip install peft") from err

logger = logging.getLogger(__name__)

LORA_READY_MARKER = ".lora_ready"
LORA_ADAPTER_NAME = "miles_lora"
LORA_SUBDIR = "tmp_lora"


def apply_lora_to_model(model: nn.Module, args) -> nn.Module:
"""Apply LoRA to model using PEFT library.

Args:
model: The base model to apply LoRA to.
args: Arguments containing LoRA configuration (lora_rank, lora_alpha,
lora_target_modules, lora_adapter_path).

Returns:
Model wrapped with LoRA adapters.
"""
if args.lora_adapter_path:
logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}")
model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True)
peft_config = model.peft_config["default"]
if isinstance(peft_config.task_type, str):
peft_config.task_type = TaskType.CAUSAL_LM
model.print_trainable_parameters()
return model

lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=args.lora_target_modules,
bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
return model


def is_lora_model(module: nn.Module) -> bool:
"""Check if a module is a PEFT LoRA model.

Args:
module: The module to check.

Returns:
True if the module has PEFT LoRA applied.
"""
unwrapped = getattr(module, "_fsdp_wrapped_module", module)
return hasattr(unwrapped, "peft_config")


def save_lora_to_disk(module: nn.Module, save_dir: str) -> str:
"""Save LoRA adapter to disk in HuggingFace PEFT format.

Args:
module: The PEFT model to save.
save_dir: Directory to save the adapter to.

Returns:
The save directory path.
"""
# Gather full state dict (all-gather LoRA weights)
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
full_state_dict = get_model_state_dict(module, options=options)

lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name}

if dist.get_rank() == 0:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)

module.save_pretrained(str(save_path), state_dict=lora_state_dict)

# Sync filesystem
os.sync()

logger.info(f"Saved LoRA adapter to {save_path}")
return save_dir


def delete_lora_from_disk(save_dir: str) -> None:
"""Delete LoRA adapter files from disk.

Args:
save_dir: Directory containing the adapter to delete.
"""
save_path = Path(save_dir)
if save_path.exists():
shutil.rmtree(save_path)
logger.info(f"Deleted LoRA adapter from {save_path}")
3 changes: 0 additions & 3 deletions miles/models/peft/__init__.py

This file was deleted.
