Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -148,6 +150,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -215,6 +219,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -282,6 +288,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -349,6 +357,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -416,6 +426,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down Expand Up @@ -449,7 +461,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}]
info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down Expand Up @@ -483,6 +495,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
{'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8},
{'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2},
{'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2},
{'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'},
{'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'},
],
},
} %>
Expand Down Expand Up @@ -153,6 +155,8 @@ jobs:

- name: Execute
shell: bash
env:
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}

- name: Post-test cleanup
Expand Down
11 changes: 9 additions & 2 deletions miles/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function
from . import checkpoint
from .lora_utils import apply_lora_to_model, is_lora_model
from .lr_scheduler import get_lr_scheduler
from .parallel import create_fsdp_parallel_state
from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor
Expand Down Expand Up @@ -99,6 +100,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
attn_implementation=self.args.attn_implementation,
)

if self.args.lora_rank > 0 or self.args.lora_adapter_path:
model = apply_lora_to_model(model, self.args)

model.train()

full_state = model.state_dict()
Expand All @@ -112,11 +116,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
self.model = model

if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
# Avoid "does not require grad" error
gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

if args.optimizer == "adam":
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
trainable_params,
lr=args.lr,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
Expand Down
47 changes: 36 additions & 11 deletions miles/backends/fsdp_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,34 @@
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

from miles.backends.fsdp_utils.lora_utils import is_lora_model

logger = logging.getLogger(__name__)


class ModelState(Stateful):
"""Wrapper for model state only."""

def __init__(self, model):
def __init__(self, model, lora_only: bool = False):
self.model = model
self.lora_only = lora_only
self._key = "adapter" if lora_only else "model"

def state_dict(self):
model_state_dict, _ = get_state_dict(self.model, optimizers=[])
return {"model": model_state_dict}
if self.lora_only:
model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
return {self._key: model_state_dict}

def load_state_dict(self, state_dict):
set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None)
data = state_dict[self._key]

if self.lora_only:
full_state_dict, _ = get_state_dict(self.model, optimizers=[])
full_state_dict.update(data)
set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
else:
set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)


class OptimizerState(Stateful):
Expand Down Expand Up @@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None:
model_dir = checkpoint_dir / "model"
optimizer_dir = checkpoint_dir / "optimizer"
lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
lora_dir = checkpoint_dir / "adapter"

lora_only = lora_dir.exists() and is_lora_model(actor.model)
model_dir = lora_dir if lora_only else model_dir

if not model_dir.exists():
logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.")
return None

# Load model weights (always)
model_state = ModelState(actor.model)
model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}

try:
dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir))
logger.info(f"[FSDP] Loaded model from {model_dir}")
logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}")
except Exception as e:
logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}")
logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}")
return None

# Load optimizer state (optional)
Expand Down Expand Up @@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None:
dist.barrier()

# Save model weights
model_state = ModelState(actor.model)
lora_only = is_lora_model(actor.model)
if lora_only:
save_dir = checkpoint_dir / "adapter"
if dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dist.barrier()
else:
save_dir = model_dir

model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}
dcp.save(state_dict, checkpoint_id=str(model_dir))
dcp.save(state_dict, checkpoint_id=str(save_dir))
logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}")

# Save optimizer state (skip if --no-save-optim is set)
save_optimizer_state = not getattr(actor.args, "no_save_optim", False)
Expand Down
55 changes: 55 additions & 0 deletions miles/backends/fsdp_utils/lora_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
from typing import Any

import torch.nn as nn

logger = logging.getLogger(__name__)

LORA_ADAPTER_NAME = "miles_lora"
LORA_SUBDIR = "tmp_lora"


def apply_lora_to_model(model: nn.Module, args) -> nn.Module:
    """Wrap ``model`` with a PEFT LoRA adapter as described by ``args``.

    When ``args.lora_adapter_path`` is truthy, an existing adapter is loaded
    from that path in trainable mode. Otherwise a fresh adapter is created
    from ``args.lora_rank``, ``args.lora_alpha`` and ``args.target_modules``.

    Args:
        model: The base model to wrap.
        args: Object exposing ``lora_adapter_path``, ``lora_rank``,
            ``lora_alpha`` and ``target_modules`` attributes.

    Returns:
        The PEFT-wrapped model.

    Raises:
        ImportError: If the ``peft`` package cannot be imported.
    """
    # peft is imported lazily so the module can be loaded without it installed.
    try:
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model
    except ImportError as err:
        raise ImportError("peft library required for LoRA. Install with: pip install peft") from err

    if not args.lora_adapter_path:
        # No saved adapter was given: build a new LoRA adapter from the CLI arguments.
        fresh_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=args.target_modules,
            bias="none",
        )
        wrapped = get_peft_model(model, fresh_config)
        wrapped.print_trainable_parameters()
        logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
        return wrapped

    # A saved adapter was given: load it and keep its weights trainable.
    logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}")
    wrapped = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True)
    loaded_config = wrapped.peft_config["default"]
    # Any str-typed task_type is replaced with the CAUSAL_LM enum member.
    # NOTE(review): if TaskType is a str-based enum this also fires for values
    # that are already enum members -- confirm that is intended.
    if isinstance(loaded_config.task_type, str):
        loaded_config.task_type = TaskType.CAUSAL_LM
    wrapped.print_trainable_parameters()
    return wrapped


def is_lora_model(module: nn.Module) -> bool:
    """Report whether ``module`` carries a ``peft_config`` attribute.

    If ``module`` has a ``_fsdp_wrapped_module`` attribute, that inner object
    is inspected instead of ``module`` itself.
    """
    try:
        inner = module._fsdp_wrapped_module
    except AttributeError:
        # No wrapper attribute present: inspect the module directly.
        inner = module
    return hasattr(inner, "peft_config")


def get_lora_config(module: nn.Module) -> dict[str, Any]:
    """Extract the LoRA settings of the ``"default"`` adapter from a PEFT model.

    Args:
        module: An object whose ``peft_config`` mapping has a ``"default"``
            entry exposing ``r``, ``lora_alpha``, ``target_modules`` and
            ``bias`` attributes.

    Returns:
        A dict with the keys ``peft_type`` (always ``"LORA"``), ``r``,
        ``lora_alpha``, ``target_modules`` and ``bias``. ``target_modules`` is
        returned as a list when the config holds a collection of names, and
        unchanged when it is a single string or ``None``.

    Raises:
        AttributeError: If ``module`` has no ``peft_config`` attribute.
        KeyError: If ``peft_config`` has no ``"default"`` entry.
    """
    peft_config = module.peft_config["default"]

    target_modules = peft_config.target_modules
    # A string value (a regex or a shorthand such as "all-linear") must stay a
    # string: list("all-linear") would split it into single characters.
    # None is passed through instead of raising TypeError in list().
    if target_modules is not None and not isinstance(target_modules, str):
        target_modules = list(target_modules)

    return {
        "peft_type": "LORA",
        "r": peft_config.r,
        "lora_alpha": peft_config.lora_alpha,
        "target_modules": target_modules,
        "bias": peft_config.bias,
    }
Loading
Loading