32 changes: 29 additions & 3 deletions miles/backends/fsdp_utils/actor.py
@@ -26,6 +26,7 @@
 
 from ...utils import tracking_utils
 from ...utils.profile_utils import TrainProfiler
+from ...models.peft import LoRAConfig, apply_lora
 from . import checkpoint
 from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
 from .lr_scheduler import get_lr_scheduler
@@ -94,6 +95,16 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
             attn_implementation=self.args.attn_implementation,
         )
 
+        if args.use_lora:
+            lora_config = LoRAConfig(
+                lora_rank=args.lora_rank,
+                lora_alpha=args.lora_alpha,
+                lora_dropout=args.lora_dropout,
+                target_modules=args.lora_target_modules,
+            )
+            model = apply_lora(model, lora_config)
+            logger.info(f"[Rank {dist.get_rank()}] Applied LoRA: {lora_config}")
+
         model.train()
 
         full_state = model.state_dict()
@@ -107,11 +118,22 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
         self.model = model
 
         if args.gradient_checkpointing:
-            self.model.gradient_checkpointing_enable()
+            # Use non-reentrant mode for gradient checkpointing
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
         if args.optimizer == "adam":
+            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
+
+            if args.use_lora:
+                total_params = sum(p.numel() for p in self.model.parameters())
+                trainable_count = sum(p.numel() for p in trainable_params)
+                logger.info(
+                    f"[Rank {dist.get_rank()}] LoRA: {trainable_count:,} trainable params "
+                    f"out of {total_params:,} total ({100 * trainable_count / total_params:.2f}%)"
+                )
+
             self.optimizer = torch.optim.AdamW(
-                self.model.parameters(),
+                trainable_params,
                 lr=args.lr,
                 betas=(args.adam_beta1, args.adam_beta2),
                 eps=args.adam_eps,
@@ -322,7 +344,11 @@ def save_model(self, iteration: int) -> None:
         if self.args.debug_rollout_only or self.args.save is None:
             return
 
-        checkpoint.save(self, iteration)
+        keys_filter = None
+        if self.args.use_lora:
+            keys_filter = lambda k: "lora_" in k
+
+        checkpoint.save(self, iteration, keys_filter=keys_filter)
 
     def _compute_log_prob(
         self,
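The `miles/models/peft/lora.py` module referenced above is not included in this diff, so the internals of `LoRAConfig` and `apply_lora` are not visible here. As a rough sketch of what a compatible implementation could look like — an assumption, not the PR's actual code — the key requirement is that adapter parameters carry `lora_` in their names, since that is what the `"lora_" in k` filter in `save_model` and `checkpoint.py` keys on:

import math
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    # Hypothetical sketch: frozen base projection plus a trainable low-rank
    # update, y = W x + (alpha / r) * B A x.
    def __init__(self, base: nn.Linear, rank: int, alpha: int, dropout: float):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        # Parameter names contain "lora_", matching the checkpoint keys_filter.
        self.lora_A = nn.Parameter(torch.empty(rank, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no-op at step 0
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling


def apply_lora(model: nn.Module, config) -> nn.Module:
    # Freeze the full model, then swap each targeted nn.Linear for a LoRALinear wrapper.
    for p in model.parameters():
        p.requires_grad_(False)
    for module in list(model.modules()):
        for name, child in list(module.named_children()):
            if name in config.target_modules and isinstance(child, nn.Linear):
                setattr(module, name, LoRALinear(child, config.lora_rank, config.lora_alpha, config.lora_dropout))
    return model

With a shape like this, the `trainable_params` filter in `init` picks up exactly the `lora_A`/`lora_B` parameters, which is why the logged trainable percentage comes out small.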
74 changes: 64 additions & 10 deletions miles/backends/fsdp_utils/checkpoint.py
@@ -6,43 +6,67 @@
 from pathlib import Path
 from typing import Any
 
+import safetensors.torch
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions
 from torch.distributed.checkpoint.stateful import Stateful
+from miles.models.peft import LoRAConfig
 
 logger = logging.getLogger(__name__)
 
 
 class ModelState(Stateful):
     """Wrapper for model state only."""
 
-    def __init__(self, model):
+    def __init__(self, model, keys_filter=None):
         self.model = model
+        self.keys_filter = keys_filter
 
     def state_dict(self):
         model_state_dict, _ = get_state_dict(self.model, optimizers=[])
+        if self.keys_filter:
+            model_state_dict = {k: v for k, v in model_state_dict.items() if self.keys_filter(k)}
         return {"model": model_state_dict}
 
     def load_state_dict(self, state_dict):
-        set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None)
+        options = None
+        if self.keys_filter:
+            # For filtered loading (e.g., LoRA), use strict=False to allow partial loading
+            options = StateDictOptions(strict=False)
+        set_state_dict(
+            self.model, optimizers=[],
+            model_state_dict=state_dict["model"],
+            optim_state_dict=None,
+            options=options
+        )
 
 
 class OptimizerState(Stateful):
     """Wrapper for optimizer state only."""
 
-    def __init__(self, model, optimizer):
+    def __init__(self, model, optimizer, keys_filter=None):
         self.model = model
         self.optimizer = optimizer
+        self.keys_filter = keys_filter
 
     def state_dict(self):
         _, optimizer_state_dict = get_state_dict(self.model, optimizers=self.optimizer)
+        if self.keys_filter:
+            optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if self.keys_filter(k)}
         return {"optim": optimizer_state_dict}
 
     def load_state_dict(self, state_dict):
+        options = None
+        if self.keys_filter:
+            # For filtered loading (e.g., LoRA), use strict=False to allow partial loading
+            options = StateDictOptions(strict=False)
         set_state_dict(
-            self.model, optimizers=self.optimizer, model_state_dict=None, optim_state_dict=state_dict["optim"]
+            self.model, optimizers=self.optimizer,
+            model_state_dict=None,
+            optim_state_dict=state_dict["optim"],
+            options=options
         )
 
 
@@ -108,8 +132,13 @@ def load(actor: Any) -> dict[str, Any] | None:
         logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
         return None
 
+    keys_filter = None
+    if actor.args.use_lora:
+        keys_filter = lambda k: "lora_" in k
+        logger.info("[FSDP] LoRA mode: loading only LoRA weights from checkpoint")
+
     # Load model weights (always)
-    model_state = ModelState(actor.model)
+    model_state = ModelState(actor.model, keys_filter=keys_filter)
     state_dict = {"model_state": model_state}
 
     try:
@@ -122,7 +151,7 @@
     # Load optimizer state (optional)
     load_optimizer = not getattr(actor.args, "no_load_optim", False) and hasattr(actor, "optimizer")
    if load_optimizer and optimizer_dir.exists():
-        optimizer_state = OptimizerState(actor.model, actor.optimizer)
+        optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter)
         optim_state_dict = {"optim_state": optimizer_state}
         try:
             dcp.load(state_dict=optim_state_dict, checkpoint_id=str(optimizer_dir))
@@ -187,7 +216,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None
     dist.barrier()
 
 
-def save(actor: Any, iteration: int) -> None:
+def save(actor: Any, iteration: int, keys_filter=None) -> None:
     """Save checkpoint to disk.
 
     Saves model weights and optimizer state to separate directories.
@@ -210,13 +239,13 @@ def save(actor: Any, iteration: int) -> None:
     dist.barrier()
 
     # Save model weights
-    model_state = ModelState(actor.model)
+    model_state = ModelState(actor.model, keys_filter=keys_filter)
     state_dict = {"model_state": model_state}
     dcp.save(state_dict, checkpoint_id=str(model_dir))
 
     # Save optimizer state
     if hasattr(actor, "optimizer") and actor.optimizer is not None:
-        optimizer_state = OptimizerState(actor.model, actor.optimizer)
+        optimizer_state = OptimizerState(actor.model, actor.optimizer, keys_filter=keys_filter)
         optim_state_dict = {"optim_state": optimizer_state}
         dcp.save(optim_state_dict, checkpoint_id=str(optimizer_dir))
 
@@ -246,4 +275,29 @@ def save(actor: Any, iteration: int) -> None:
         tracker_file.write_text(str(step_id))
     logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}")
 
+    if actor.args.use_lora:
+        _save_hf_lora(actor, checkpoint_dir)
+
     dist.barrier()
+
+
+def _save_hf_lora(actor: Any, checkpoint_dir: Path) -> None:
+    """Save LoRA adapter in Hugging Face PEFT format."""
+
+    options = dcp.state_dict.StateDictOptions(full_state_dict=True, cpu_offload=True)
+    full_state_dict = get_model_state_dict(actor.model, options=options)
+
+    if dist.get_rank() == 0:
+        lora_config = LoRAConfig(
+            lora_rank=actor.args.lora_rank,
+            lora_alpha=actor.args.lora_alpha,
+            lora_dropout=actor.args.lora_dropout,
+            target_modules=actor.args.lora_target_modules,
+        )
+        peft_config = lora_config.to_hf_peft_config()
+        with open(checkpoint_dir / "adapter_config.json", "w") as f:
+            json.dump(peft_config, f, indent=2)
+
+        lora_state_dict = {k: v for k, v in full_state_dict.items() if "lora_" in k}
+        safetensors.torch.save_file(lora_state_dict, checkpoint_dir / "adapter_model.safetensors")
+        logger.info(f"[FSDP] Saved HF LoRA adapter to {checkpoint_dir}")
3 changes: 3 additions & 0 deletions miles/models/peft/__init__.py
@@ -0,0 +1,3 @@
+from .lora import LoRAConfig, LoRALinear, apply_lora, get_lora_state_dict, load_lora_state_dict
+from .arguments import add_lora_arguments
+
38 changes: 38 additions & 0 deletions miles/models/peft/arguments.py
@@ -0,0 +1,38 @@
+import argparse
+
+def add_lora_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+    """Add LoRA arguments to the parser."""
+    group = parser.add_argument_group(title="LoRA")
+
+    group.add_argument(
+        "--use-lora",
+        action="store_true",
+        help="Whether to use LoRA for training.",
+    )
+    group.add_argument(
+        "--lora-rank",
+        type=int,
+        default=8,
+        help="LoRA rank.",
+    )
+    group.add_argument(
+        "--lora-alpha",
+        type=int,
+        default=16,
+        help="LoRA alpha.",
+    )
+    group.add_argument(
+        "--lora-dropout",
+        type=float,
+        default=0.0,
+        help="LoRA dropout.",
+    )
+    group.add_argument(
+        "--lora-target-modules",
+        type=str,
+        nargs="+",
+        default=["q_proj", "v_proj"],
+        help="List of module names to apply LoRA to.",
+    )
+
+    return parser
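For a quick check of how these flags parse, `add_lora_arguments` can be exercised standalone; this snippet uses only what the diff defines:

import argparse

from miles.models.peft import add_lora_arguments

parser = add_lora_arguments(argparse.ArgumentParser())
args = parser.parse_args(
    ["--use-lora", "--lora-rank", "16", "--lora-alpha", "32",
     "--lora-target-modules", "q_proj", "k_proj", "v_proj", "o_proj"]
)
print(args.use_lora, args.lora_rank, args.lora_alpha, args.lora_target_modules)
# True 16 32 ['q_proj', 'k_proj', 'v_proj', 'o_proj']

Note that `--lora-target-modules` takes a space-separated list (`nargs="+"`), and the defaults (rank 8, alpha 16, dropout 0.0, `q_proj`/`v_proj`) follow the common LoRA convention of alpha = 2 × rank.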