From 0baae25f6267d08270c6a9cc46b220806e549e11 Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Sat, 20 Dec 2025 01:09:12 -0600 Subject: [PATCH 1/8] Add LoRA for FSDP backend. (#307) Co-authored-by: PopSoda2002 --- miles/backends/fsdp_utils/actor.py | 11 +- miles/backends/fsdp_utils/arguments.py | 7 ++ miles/backends/fsdp_utils/checkpoint.py | 47 ++++++-- miles/backends/fsdp_utils/lora_utils.py | 77 ++++++++++++ .../fsdp_utils/update_weight_utils.py | 112 +++++++++++++----- miles/backends/sglang_utils/sglang_engine.py | 26 +++- miles/ray/placement_group.py | 1 + miles/ray/rollout.py | 4 +- miles/rollout/sglang_rollout.py | 5 + miles/utils/arguments.py | 30 +++++ requirements.txt | 1 + train.py | 13 +- 12 files changed, 284 insertions(+), 50 deletions(-) create mode 100644 miles/backends/fsdp_utils/lora_utils.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 1e3e5b3ae..2cfbc4339 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -28,6 +28,7 @@ from ...utils.profile_utils import TrainProfiler from . import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences +from .lora_utils import apply_lora_to_model, is_lora_model from .lr_scheduler import get_lr_scheduler from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor @@ -94,6 +95,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) + if self.args.lora_rank > 0 or self.args.lora_adapter_path: + model = apply_lora_to_model(model, self.args) + model.train() full_state = model.state_dict() @@ -107,11 +111,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + # Avoid "does not require grad" error + gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {} + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) if args.optimizer == "adam": + trainable_params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.AdamW( - self.model.parameters(), + trainable_params, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index a319fe6e5..7cd10867b 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,6 +60,13 @@ class FSDPArgs: # YAML bookkeeping config: str | None = None + # LoRA configuration + lora_rank: int = 0 + lora_alpha: int = 16 + target_modules: str = "all-linear" + exclude_modules: str | None = None + lora_adapter_path: str | None = None + def parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 3c49a10f8..8508fba2b 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -12,21 +12,34 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful +from miles.backends.fsdp_utils.lora_utils import is_lora_model + logger = logging.getLogger(__name__) class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model): + def 
__init__(self, model, lora_only: bool = False): self.model = model + self.lora_only = lora_only + self._key = "adapter" if lora_only else "model" def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) - return {"model": model_state_dict} + if self.lora_only: + model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k} + return {self._key: model_state_dict} def load_state_dict(self, state_dict): - set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) + data = state_dict[self._key] + + if self.lora_only: + full_state_dict, _ = get_state_dict(self.model, optimizers=[]) + full_state_dict.update(data) + set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None) + else: + set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None) class OptimizerState(Stateful): @@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None: model_dir = checkpoint_dir / "model" optimizer_dir = checkpoint_dir / "optimizer" lr_scheduler_dir = checkpoint_dir / "lr_scheduler" + lora_dir = checkpoint_dir / "adapter" + + lora_only = lora_dir.exists() and is_lora_model(actor.model) + model_dir = lora_dir if lora_only else model_dir if not model_dir.exists(): - logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") + logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.") return None - # Load model weights (always) - model_state = ModelState(actor.model) + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - try: dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir)) - logger.info(f"[FSDP] Loaded model from {model_dir}") + logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}") except Exception as e: - logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}") + logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}") return None # Load optimizer state (optional) @@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - model_state = ModelState(actor.model) + lora_only = is_lora_model(actor.model) + if lora_only: + save_dir = checkpoint_dir / "adapter" + if dist.get_rank() == 0: + save_dir.mkdir(parents=True, exist_ok=True) + dist.barrier() + else: + save_dir = model_dir + + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - dcp.save(state_dict, checkpoint_id=str(model_dir)) + dcp.save(state_dict, checkpoint_id=str(save_dir)) + logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}") # Save optimizer state if hasattr(actor, "optimizer") and actor.optimizer is not None: diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py new file mode 100644 index 000000000..d6483b372 --- /dev/null +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -0,0 +1,77 @@ +import logging +import os +import shutil +from pathlib import Path + +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + +try: + from peft import LoraConfig, PeftModel, TaskType, get_peft_model +except ImportError as err: + raise ImportError("peft library required for LoRA. 
Install with: pip install peft") from err + +logger = logging.getLogger(__name__) + +LORA_READY_MARKER = ".lora_ready" +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def apply_lora_to_model(model: nn.Module, args) -> nn.Module: + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return hasattr(unwrapped, "peft_config") + + +def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: + """Save LoRA adapter to disk with file lock mechanism.""" + # TODO: All gather lora layers not full layers + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(module, options=options) + + lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + + if dist.get_rank() == 0: + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + module.save_pretrained(str(save_path), state_dict=lora_state_dict) + + # TODO: check if file lock is needed or better way to do it + os.sync() + + logger.info(f"Saved LoRA adapter to {save_path}") + return save_dir + + +def delete_lora_from_disk(save_dir: str) -> None: + """Delete LoRA adapter files from disk.""" + save_path = Path(save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index c8dcbd810..6e4ee73a5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,5 +1,6 @@ import abc import logging +import os import socket from argparse import Namespace from collections.abc import Sequence @@ -25,6 +26,7 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk logger = logging.getLogger(__name__) @@ -33,6 +35,9 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model + self.weight_version = 0 + self._lora_loaded = False + self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -43,38 +48,85 @@ def connect_rollout_engines( pass def update_weights(self) -> None: - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async 
version of param.full_tensor - param = param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: - self.wait_and_update_bucket_weights(bucket) - del bucket + self.weight_version += 1 + + # Update base model if needed + # Level 1: only sync base once for LoRA models, then just LoRA + # Level 2: always sync base + LoRA + if not (is_lora_model(self.model) and self._base_synced and self.args.offload_rollout_level == 1): bucket = [] bucket_size = 0 + for name, param in self.model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("base_model.model.", "").replace(".base_layer", "") + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket + + self._base_synced = True + + # Update lora weights if needed + if is_lora_model(self.model): + self._update_lora_via_file() + + def _update_lora_via_file(self) -> None: + """Push LoRA weights to rollout engines using disk files.""" + self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) + if dist.get_rank() == 0: + if os.path.exists(self._lora_save_dir): + delete_lora_from_disk(self._lora_save_dir) + + dist.barrier() + + save_lora_to_disk(self.model, self._lora_save_dir) + + dist.barrier() + + if dist.get_rank() == 0: + if self._lora_loaded: + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + refs = [ + engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) + for engine in self.rollout_engines + ] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + self._lora_loaded = True + + dist.barrier() def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket) + self.update_bucket_weights(bucket, weight_version=self.weight_version) @abc.abstractmethod - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: pass @@ -114,7 +166,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -162,6 +214,7 @@ def update_bucket_weights(self, named_tensors) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, + "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) 
ray.get(ref) @@ -174,10 +227,6 @@ def update_bucket_weights(self, named_tensors) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" - def __init__(self, args: Namespace, model: torch.nn.Module) -> None: - self.args = args - self.model = model - def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -220,7 +269,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -235,6 +284,7 @@ def update_bucket_weights(self, named_tensors) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, + weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 2e1afe625..c9a774b32 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -278,9 +278,15 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self): + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage resume: weights, kv_cache + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) def resume_memory_occupation(self, tags: list[str] = None): """ @@ -336,6 +342,18 @@ def update_weights_from_distributed( payload, ) + def load_lora_adapter(self, lora_name: str, lora_path: str): + return self._make_request( + "load_lora_adapter", + {"lora_name": lora_name, "lora_path": lora_path}, + ) + + def unload_lora_adapter(self, lora_name: str): + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + def pause_generation(self): response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -419,6 +437,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" + if args.lora_rank > 0 or args.lora_adapter_path is not None: + kwargs["enable_lora"] = True + kwargs["max_lora_rank"] = args.lora_rank + kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index b6fb7a20b..7bd842960 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -177,6 +177,7 @@ def create_rollout_manager(args, pg): ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) if args.offload_rollout: + # TODO: Optimization in the future: offload model weights to cpu to make more space for training? 
ray.get(rollout_manager.offload.remote()) return rollout_manager, num_rollout_per_epoch diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 9ee0fbb8a..83c3a5519 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -127,8 +127,8 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + def offload(self, tags: list[str] = None): + return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 2e33542a5..36f6e7ce0 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -11,6 +11,7 @@ from packaging.version import parse from tqdm import tqdm +from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.utils.async_utils import run @@ -124,6 +125,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + # Use LoRA adapter when LoRA is enabled + if args.lora_rank > 0 or args.lora_adapter_path is not None: + payload["lora_path"] = LORA_ADAPTER_NAME + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..b1c425550 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,6 +104,15 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." ), ) + parser.add_argument( + "--offload-rollout-level", + type=int, + default=2, + help=( + "The offload level for rollout when offload-rollout is set. " + "1 means only offload kv cache, 2 means offload kv cache and weights." + ), + ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -1415,6 +1424,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.lora_rank > 0: + # assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
+ + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 2c20195fc..3840f294d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf +peft pillow pylatexenc pyyaml diff --git a/train.py b/train.py index 9fb480eda..a43aa61db 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if args.offload_rollout: + if args.offload_rollout and args.offload_rollout_level == 2: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. @@ -68,7 +68,16 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + # level 1: offload kv cache only, level 2: offload weights + kv cache + ray.get( + rollout_manager.offload.remote( + tags=( + [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] + if args.offload_rollout_level == 1 + else [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH] + ) + ) + ) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 09304c358e7d20c62ef3b52aba0b52fb6dbbd2d0 Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Thu, 25 Dec 2025 13:29:07 -0600 Subject: [PATCH 2/8] Add CI/CD tests for LoRA FSDP. 
(#351) Co-authored-by: PopSoda2002 --- .github/workflows/pr-test.yml | 6 ++- .github/workflows/pr-test.yml.j2 | 4 ++ miles/backends/fsdp_utils/arguments.py | 7 --- miles/backends/fsdp_utils/lora_utils.py | 1 - .../fsdp_utils/update_weight_utils.py | 4 +- miles/ray/rollout.py | 2 +- miles/utils/arguments.py | 54 +++++++++++++++++-- tests/test_external_rollout.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 2 +- tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 25 +++++++-- tests/test_qwen3_0.6B_fsdp_distributed.py | 20 +++++-- train.py | 18 +++---- 13 files changed, 108 insertions(+), 39 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index e649da717..d662dbaf1 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -63,6 +63,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-long: @@ -84,7 +86,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -103,4 +105,6 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 06d6ed570..56e24b23f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -13,6 +13,8 @@ {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'}, + {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'}, ], }, } %> @@ -77,5 +79,7 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index 7cd10867b..a319fe6e5 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,13 +60,6 @@ class FSDPArgs: # YAML bookkeeping config: str | None = None - # LoRA configuration - lora_rank: int = 0 - lora_alpha: int = 16 - target_modules: str = "all-linear" - exclude_modules: str | None = None - lora_adapter_path: str | None = None - def 
parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py index d6483b372..f7f85d84b 100644 --- a/miles/backends/fsdp_utils/lora_utils.py +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -14,7 +14,6 @@ logger = logging.getLogger(__name__) -LORA_READY_MARKER = ".lora_ready" LORA_ADAPTER_NAME = "miles_lora" LORA_SUBDIR = "tmp_lora" diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 6e4ee73a5..f48cca7f5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -51,9 +51,7 @@ def update_weights(self) -> None: self.weight_version += 1 # Update base model if needed - # Level 1: only sync base once for LoRA models, then just LoRA - # Level 2: always sync base + LoRA - if not (is_lora_model(self.model) and self._base_synced and self.args.offload_rollout_level == 1): + if not (is_lora_model(self.model) and self._base_synced and "weight" not in self.args.offload_rollout_level): bucket = [] bucket_size = 0 for name, param in self.model.state_dict().items(): diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 83c3a5519..cb82c8ce8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -412,7 +412,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_new_engines = len(rollout_engines) if num_new_engines == 0: - return num_new_engines, None + return num_new_engines if args.rollout_external: addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b1c425550..22172a35c 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -106,11 +106,13 @@ def add_cluster_arguments(parser): ) parser.add_argument( "--offload-rollout-level", - type=int, - default=2, + type=str, + nargs="+", + default=["kv_cache", "weight"], help=( - "The offload level for rollout when offload-rollout is set. " - "1 means only offload kv cache, 2 means offload kv cache and weights." + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" ), ) @@ -850,6 +852,41 @@ def add_algo_arguments(parser): default=1e-4, help="The threshold for Off-Policy Sequence Masking (OPSM).", ) + parser.add_argument( + "--lora-rank", + type=int, + default=0, + help="LoRA rank. Set to 0 to disable LoRA (default: 0)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha parameter (default: 16)", + ) + parser.add_argument( + "--target-modules", + type=str, + default=None, + help=( + "Target modules for LoRA adaptation. " + "Can be 'all-linear', a single module name, or comma-separated module names. 
" + "Example: 'q_proj,k_proj,v_proj' (default: None)" + ), + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Comma-separated list of modules to exclude from LoRA adaptation (default: None)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to load pre-trained LoRA adapter weights (default: None)", + ) + return parser def add_router_arguments(parser): @@ -1394,6 +1431,13 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def miles_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + # Check if LoRA is enabled with Megatron backend (not yet implemented) + if args.lora_rank > 0 and args.train_backend == "megatron": + raise NotImplementedError( + "LoRA is not yet implemented for Megatron backend. " + "Please use FSDP backend (--train-backend fsdp) or disable LoRA (--lora-rank 0)." + ) + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1425,7 +1469,7 @@ def miles_validate_args(args): assert args.save is not None, "'--save' is required when save_interval is set." if args.lora_rank > 0: - # assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.save is not None, "'--save' is required when LoRA is enabled." assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." if args.target_modules == "all-linear": diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..f12837d88 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 6302aadb6..26d3e3197 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -11,7 +11,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 1c55ccb20..878d68b1c 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -10,7 +10,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 6967f9145..460f8a5a5 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,17 +1,29 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" +ENABLE_LORA 
= U.get_bool_env_var("ENABLE_LORA", "0") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -50,7 +62,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -73,10 +85,17 @@ def execute(): "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step ) - misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + misc_args = ( + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 2 " + "--colocate " + "--offload-rollout-level kv_cache weight " + "--train-backend fsdp " + ) train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index b3eb416b3..c4592ffdf 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,20 +1,30 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" - - +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -54,7 +64,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -67,6 +77,7 @@ def execute(): "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--offload-rollout-level kv_cache weight " "--train-backend fsdp " ) @@ -79,6 +90,7 @@ def execute(): train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/train.py b/train.py index a43aa61db..212361fc0 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if args.offload_rollout and args.offload_rollout_level == 2: + if args.offload_rollout and "weight" in args.offload_rollout_level: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. 
@@ -68,16 +68,12 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - # level 1: offload kv cache only, level 2: offload weights + kv cache - ray.get( - rollout_manager.offload.remote( - tags=( - [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] - if args.offload_rollout_level == 1 - else [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH] - ) - ) - ) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 62bb5bace80dba9cfdecb2357c52386199fca939 Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Thu, 25 Dec 2025 13:29:07 -0600 Subject: [PATCH 3/8] Add CI/CD tests for LoRA FSDP. (#351) Co-authored-by: PopSoda2002 --- .github/workflows/pr-test.yml | 8 +- .github/workflows/pr-test.yml.j2 | 4 + miles/backends/fsdp_utils/lora_utils.py | 76 +++++++++++++++++++ .../fsdp_utils/update_weight_utils.py | 50 +++++++----- miles/utils/arguments.py | 74 ++++++++++++++++++ tests/test_external_rollout.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 2 +- tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 25 +++++- tests/test_qwen3_0.6B_fsdp_distributed.py | 20 ++++- train.py | 12 ++- 11 files changed, 242 insertions(+), 33 deletions(-) create mode 100644 miles/backends/fsdp_utils/lora_utils.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 8ea939739..2702192fc 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -243,6 +243,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-long: @@ -268,7 +270,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -287,6 +289,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-image: @@ -331,4 +335,6 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 84cac9114..70ca765db 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -65,6 +65,8 @@ {'test_file': 
'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'}, + {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'}, ], }, } %> @@ -133,5 +135,7 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py new file mode 100644 index 000000000..f7f85d84b --- /dev/null +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -0,0 +1,76 @@ +import logging +import os +import shutil +from pathlib import Path + +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + +try: + from peft import LoraConfig, PeftModel, TaskType, get_peft_model +except ImportError as err: + raise ImportError("peft library required for LoRA. Install with: pip install peft") from err + +logger = logging.getLogger(__name__) + +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def apply_lora_to_model(model: nn.Module, args) -> nn.Module: + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return hasattr(unwrapped, "peft_config") + + +def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: + """Save LoRA adapter to disk with file lock mechanism.""" + # TODO: All gather lora layers not full layers + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(module, options=options) + + lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + + if dist.get_rank() == 0: + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + module.save_pretrained(str(save_path), state_dict=lora_state_dict) + + # TODO: check if file lock is needed or better way to do it + os.sync() + + logger.info(f"Saved LoRA adapter to {save_path}") + return save_dir + + +def delete_lora_from_disk(save_dir: str) -> None: + """Delete LoRA adapter files from disk.""" + save_path = Path(save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index d0f2360ab..347f774c3 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ 
b/miles/backends/fsdp_utils/update_weight_utils.py @@ -17,6 +17,7 @@ from sglang.srt.utils import MultiprocessingSerializer +from miles.backends.fsdp_utils.lora_utils import is_lora_model from miles.utils.distributed_utils import init_process_group @@ -34,6 +35,7 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model self.weight_version = 0 + self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -45,32 +47,38 @@ def connect_rollout_engines( def update_weights(self) -> None: self.weight_version += 1 - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: + + # Update base model if needed + if not (is_lora_model(self.model) and self._base_synced and "weight" not in self.args.offload_rollout_level): + bucket = [] + bucket_size = 0 + for name, param in self.model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("base_model.model.", "").replace(".base_layer", "") + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: self.wait_and_update_bucket_weights(bucket) del bucket bucket = [] bucket_size = 0 - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] self.update_bucket_weights(bucket, weight_version=self.weight_version) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 51b3d970b..08a60cae1 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -103,6 +103,17 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." ), ) + parser.add_argument( + "--offload-rollout-level", + type=str, + nargs="+", + default=["kv_cache", "weight"], + help=( + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" + ), + ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -920,6 +931,41 @@ def add_algo_arguments(parser): default=1e-4, help="The threshold for Off-Policy Sequence Masking (OPSM).", ) + parser.add_argument( + "--lora-rank", + type=int, + default=0, + help="LoRA rank. 
Set to 0 to disable LoRA (default: 0)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha parameter (default: 16)", + ) + parser.add_argument( + "--target-modules", + type=str, + default=None, + help=( + "Target modules for LoRA adaptation. " + "Can be 'all-linear', a single module name, or comma-separated module names. " + "Example: 'q_proj,k_proj,v_proj' (default: None)" + ), + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Comma-separated list of modules to exclude from LoRA adaptation (default: None)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to load pre-trained LoRA adapter weights (default: None)", + ) + return parser def add_router_arguments(parser): @@ -1491,6 +1537,13 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def miles_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + # Check if LoRA is enabled with Megatron backend (not yet implemented) + if args.lora_rank > 0 and args.train_backend == "megatron": + raise NotImplementedError( + "LoRA is not yet implemented for Megatron backend. " + "Please use FSDP backend (--train-backend fsdp) or disable LoRA (--lora-rank 0)." + ) + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1526,6 +1579,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.lora_rank > 0: + assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
+ + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..f12837d88 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index dcdbd5834..e33698af3 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -12,7 +12,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f..f52c80afe 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -11,7 +11,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ce..94c5157c7 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -2,17 +2,29 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -51,7 +63,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr 
{'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -74,10 +86,17 @@ def execute(): "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step ) - misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + misc_args = ( + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 2 " + "--colocate " + "--offload-rollout-level kv_cache weight " + "--train-backend fsdp " + ) train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4c..9b11f9221 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -2,20 +2,30 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" - - +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -55,7 +65,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -68,6 +78,7 @@ def execute(): "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--offload-rollout-level kv_cache weight " "--train-backend fsdp " ) @@ -80,6 +91,7 @@ def execute(): train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/train.py b/train.py index 745dcbed6..e780811f5 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import ray +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from miles.utils.arguments import parse_args @@ -61,6 +62,10 @@ def save(rollout_id): if args.rollout_global_dataset: ray.get(rollout_manager.save.remote(rollout_id)) + def onload_rollout(): + if args.offload_rollout and "weight" in args.offload_rollout_level: + ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # train loop. # note that for async training, one can change the position of the sync operation(ray.get). 
for rollout_id in range(args.start_rollout_id, args.num_rollout): @@ -70,7 +75,12 @@ def save(rollout_id): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 3d6484cd9768d46a59cc803c50b5bd72746ee1dd Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Tue, 30 Dec 2025 16:41:50 +0000 Subject: [PATCH 4/8] Update LoRA Weights to Rollout Engine via Tensor Co-authored-by: PopSoda2002 --- .github/workflows/pr-test.yml | 12 ++- miles/backends/fsdp_utils/lora_utils.py | 27 ++++++ .../fsdp_utils/update_weight_utils.py | 82 ++++++++++++++++++- miles/backends/sglang_utils/sglang_engine.py | 34 +++++++- miles/ray/rollout.py | 8 +- miles/utils/arguments.py | 6 ++ tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 6 ++ tests/test_qwen3_0.6B_fsdp_distributed.py | 6 ++ 8 files changed, 171 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2702192fc..a0cd4c291 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -67,6 +67,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-fsdp: @@ -111,6 +113,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-megatron: @@ -155,6 +159,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-precision: @@ -199,6 +205,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-ckpt: @@ -270,7 +278,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -316,7 +324,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": 
"test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py index f7f85d84b..1872fe59c 100644 --- a/miles/backends/fsdp_utils/lora_utils.py +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -74,3 +74,30 @@ def delete_lora_from_disk(save_dir: str) -> None: if save_path.exists(): shutil.rmtree(save_path) logger.info(f"Deleted LoRA adapter from {save_path}") + + +def get_lora_weights_and_config(module: nn.Module) -> tuple[dict[str, any], dict[str, any]]: + """Extract LoRA weights and config from PEFT model for tensor-based sync.""" + # TODO: only gather lora weights, or gather lora weights in bucket logic i.e., layered summon + # options = StateDictOptions(full_state_dict=True, cpu_offload=True) + options = StateDictOptions(full_state_dict=True, cpu_offload=False) + full_state_dict = get_model_state_dict(module, options=options) + + state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + if dist.get_rank() == 0: + logger.info(f"Extracted {len(state_dict)} LoRA weight tensors") + + for name in list(state_dict.keys()): + key = name.replace(".default.weight", ".weight") # .replace("base_model.model.", "") + state_dict[key] = state_dict.pop(name) + + peft_config = module.peft_config["default"] + config_dict = { + "peft_type": "LORA", + "r": peft_config.r, + "lora_alpha": peft_config.lora_alpha, + "target_modules": list(peft_config.target_modules), + "bias": peft_config.bias, + } + + return state_dict, config_dict diff --git 
a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 347f774c3..7b7674337 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,5 +1,6 @@ import abc import logging +import os import socket from argparse import Namespace from collections.abc import Sequence @@ -17,15 +18,21 @@ from sglang.srt.utils import MultiprocessingSerializer -from miles.backends.fsdp_utils.lora_utils import is_lora_model from miles.utils.distributed_utils import init_process_group - try: from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +from .lora_utils import ( + LORA_ADAPTER_NAME, + LORA_SUBDIR, + delete_lora_from_disk, + get_lora_weights_and_config, + is_lora_model, + save_lora_to_disk, +) logger = logging.getLogger(__name__) @@ -36,6 +43,7 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.model = model self.weight_version = 0 self._base_synced = False + self._lora_loaded = False @abc.abstractmethod def connect_rollout_engines( @@ -76,8 +84,74 @@ def update_weights(self) -> None: if bucket: self.wait_and_update_bucket_weights(bucket) del bucket - bucket = [] - bucket_size = 0 + + self._base_synced = True + + # Update lora weights if needed + if is_lora_model(self.model): + if self.args.lora_sync_from_tensor: + self._update_lora_via_tensor() + else: + self._update_lora_via_file() + + def _update_lora_via_file(self) -> None: + """Push LoRA weights to rollout engines using disk files.""" + self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) + if dist.get_rank() == 0: + if os.path.exists(self._lora_save_dir): + delete_lora_from_disk(self._lora_save_dir) + + dist.barrier() + + save_lora_to_disk(self.model, self._lora_save_dir) + + dist.barrier() + + if dist.get_rank() == 0: + if self._lora_loaded: + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + refs = [ + engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) + for engine in self.rollout_engines + ] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + self._lora_loaded = True + + dist.barrier() + + def _update_lora_via_tensor(self) -> None: + """Push LoRA weights to rollout engines using tensors.""" + lora_weights, config_dict = get_lora_weights_and_config(self.model) + dist.barrier() + + if dist.get_rank() == 0: + serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True) + + if self._lora_loaded: + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + ray.get(refs) + + refs = [ + engine.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, serialized_tensors, config_dict) + for engine in self.rollout_engines + ] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + self._lora_loaded = True + + dist.barrier() def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 
179306023..a32e27a11 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -333,9 +333,15 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self): + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage release: weights, kv_cache, cuda_graph + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) def resume_memory_occupation(self, tags: list[str] = None): """ @@ -391,6 +397,24 @@ def update_weights_from_distributed( payload, ) + def load_lora_adapter(self, lora_name: str, lora_path: str): + return self._make_request( + "load_lora_adapter", + {"lora_name": lora_name, "lora_path": lora_path}, + ) + + def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict): + return self._make_request( + "load_lora_adapter_from_tensors", + {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, + ) + + def unload_lora_adapter(self, lora_name: str): + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + def pause_generation(self): response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -503,6 +527,12 @@ def _compute_server_args( kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" + if args.lora_rank > 0 or args.lora_adapter_path is not None: + kwargs["enable_lora"] = True + kwargs["max_lora_rank"] = args.lora_rank + kwargs["max_loras_per_batch"] = 1 + kwargs["lora_target_modules"] = args.target_modules + external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] unused_keys = set(kwargs.keys()) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be..099e825b0 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -155,10 +155,14 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): + def offload(self, tags: list[str] | None = None): self.health_monitoring_pause() return ray.get( - [engine.release_memory_occupation.remote() for engine in self.rollout_engines if engine is not None] + [ + engine.release_memory_occupation.remote(tags=tags) + for engine in self.rollout_engines + if engine is not None + ] ) def onload(self, tags: list[str] | None = None): diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 08a60cae1..5a278c5c1 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -965,6 +965,12 @@ def add_algo_arguments(parser): default=None, help="Path to load pre-trained LoRA adapter weights (default: None)", ) + parser.add_argument( + "--lora-sync-from-tensor", + action="store_true", + default=False, + help="Use tensor-based LoRA weight synchronization instead of file-based (default: False)", + ) return parser diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 94c5157c7..68642bb36 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,3 +1,4 @@ +import importlib.util import os import miles.utils.external_utils.command_utils as U @@ -6,6 +7,10 @@ def prepare(): + if ENABLE_LORA: + if 
importlib.util.find_spec("peft") is None: + U.exec_command("pip install peft") + U.exec_command("mkdir -p /root/models /root/datasets") U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") @@ -20,6 +25,7 @@ def execute(): "--lora-alpha 32 " "--target-modules all-linear " f"--save /root/models/{MODEL_NAME}-lora-ckpt " + # "--lora-sync-from-tensor " ) if ENABLE_LORA else "" diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 9b11f9221..6b07b7e91 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,3 +1,4 @@ +import importlib.util import os import miles.utils.external_utils.command_utils as U @@ -7,6 +8,10 @@ def prepare(): + if ENABLE_LORA: + if importlib.util.find_spec("peft") is None: + U.exec_command("pip install peft") + U.exec_command("mkdir -p /root/models /root/datasets") U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") @@ -21,6 +26,7 @@ def execute(): "--lora-alpha 32 " "--target-modules all-linear " f"--save /root/models/{MODEL_NAME}-lora-ckpt " + # "--lora-sync-from-tensor " ) if ENABLE_LORA else "" From d4d21f7249d19ffe17c1b25996c1bf1f58e638bc Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Mon, 12 Jan 2026 09:20:20 +0000 Subject: [PATCH 5/8] Avoid additional all-gather in obtaining LoRA weight Co-authored-by: PopSoda2002 --- miles/backends/fsdp_utils/lora_utils.py | 31 ++------- .../fsdp_utils/update_weight_utils.py | 66 +++++++++++++------ miles/backends/sglang_utils/sglang_engine.py | 3 +- 3 files changed, 53 insertions(+), 47 deletions(-) diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py index 1872fe59c..bb90d0396 100644 --- a/miles/backends/fsdp_utils/lora_utils.py +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -5,7 +5,6 @@ import torch.distributed as dist import torch.nn as nn -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict try: from peft import LoraConfig, PeftModel, TaskType, get_peft_model @@ -36,7 +35,7 @@ def apply_lora_to_model(model: nn.Module, args) -> nn.Module: bias="none", ) - model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) + model = get_peft_model(model, lora_config) model.print_trainable_parameters() logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") return model @@ -47,13 +46,9 @@ def is_lora_model(module: nn.Module) -> bool: return hasattr(unwrapped, "peft_config") -def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: +def save_lora_to_disk(module: nn.Module, save_dir: str, lora_weights: dict) -> str: """Save LoRA adapter to disk with file lock mechanism.""" - # TODO: All gather lora layers not full layers - options = StateDictOptions(full_state_dict=True, cpu_offload=True) - full_state_dict = get_model_state_dict(module, options=options) - - lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + lora_state_dict = lora_weights if dist.get_rank() == 0: save_path = Path(save_dir) @@ -76,21 +71,8 @@ def delete_lora_from_disk(save_dir: str) -> None: logger.info(f"Deleted LoRA adapter from {save_path}") -def get_lora_weights_and_config(module: nn.Module) -> tuple[dict[str, any], dict[str, any]]: - """Extract LoRA weights and config from PEFT model for tensor-based sync.""" - # TODO: only gather 
lora weights, or gather lora weights in bucket logic i.e., layered summon - # options = StateDictOptions(full_state_dict=True, cpu_offload=True) - options = StateDictOptions(full_state_dict=True, cpu_offload=False) - full_state_dict = get_model_state_dict(module, options=options) - - state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} - if dist.get_rank() == 0: - logger.info(f"Extracted {len(state_dict)} LoRA weight tensors") - - for name in list(state_dict.keys()): - key = name.replace(".default.weight", ".weight") # .replace("base_model.model.", "") - state_dict[key] = state_dict.pop(name) - +def get_lora_config(module: nn.Module) -> dict[str, any]: + """Extract LoRA config from PEFT model.""" peft_config = module.peft_config["default"] config_dict = { "peft_type": "LORA", @@ -99,5 +81,4 @@ def get_lora_weights_and_config(module: nn.Module) -> tuple[dict[str, any], dict "target_modules": list(peft_config.target_modules), "bias": peft_config.bias, } - - return state_dict, config_dict + return config_dict diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 7b7674337..380c36749 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -29,7 +29,7 @@ LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, - get_lora_weights_and_config, + get_lora_config, is_lora_model, save_lora_to_disk, ) @@ -56,13 +56,32 @@ def connect_rollout_engines( def update_weights(self) -> None: self.weight_version += 1 - # Update base model if needed - if not (is_lora_model(self.model) and self._base_synced and "weight" not in self.args.offload_rollout_level): - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - if any(x in name for x in ["_flat_param", "lora_"]): - continue + bucket = [] + bucket_size = 0 + lora_weights = {} + is_lora = is_lora_model(self.model) + should_sync_base = not (is_lora and self._base_synced and "weight" not in self.args.offload_rollout_level) + + for name, param in self.model.state_dict().items(): + # Skip FSDP internal parameters + if "_flat_param" in name: + continue + + # Extract LoRA weights + if is_lora and "lora_" in name: + param = param.cuda() + if isinstance(param, DTensor): + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + param = param.wait() if hasattr(param, "wait") else param + key = name.replace(".default.weight", ".weight") + lora_weights[key] = param + continue + + # Process base model weights + if should_sync_base: name = name.replace("base_model.model.", "").replace(".base_layer", "") param_size = param.numel() * param.element_size() if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: @@ -73,7 +92,7 @@ def update_weights(self) -> None: param = param.cuda() if isinstance(param, DTensor): - # async version of param.full_tensor + # Async version of param.full_tensor param = param.redistribute( placements=[Replicate()] * param.device_mesh.ndim, async_op=True, @@ -81,20 +100,19 @@ def update_weights(self) -> None: bucket.append((name, param)) bucket_size += param_size - if bucket: - self.wait_and_update_bucket_weights(bucket) - del bucket - + if should_sync_base and bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket self._base_synced = True - # Update lora weights if needed - if is_lora_model(self.model): + # Update LoRA weights if needed + if is_lora: if 
self.args.lora_sync_from_tensor: - self._update_lora_via_tensor() + self._update_lora_via_tensor(lora_weights) else: - self._update_lora_via_file() + self._update_lora_via_file(lora_weights) - def _update_lora_via_file(self) -> None: + def _update_lora_via_file(self, lora_weights: dict) -> None: """Push LoRA weights to rollout engines using disk files.""" self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) if dist.get_rank() == 0: @@ -103,7 +121,7 @@ def _update_lora_via_file(self) -> None: dist.barrier() - save_lora_to_disk(self.model, self._lora_save_dir) + save_lora_to_disk(self.model, self._lora_save_dir, lora_weights) dist.barrier() @@ -128,18 +146,24 @@ def _update_lora_via_file(self) -> None: dist.barrier() - def _update_lora_via_tensor(self) -> None: + def _update_lora_via_tensor(self, lora_weights: dict) -> None: """Push LoRA weights to rollout engines using tensors.""" - lora_weights, config_dict = get_lora_weights_and_config(self.model) + config_dict = get_lora_config(self.model) dist.barrier() if dist.get_rank() == 0: serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True) if self._lora_loaded: + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] ray.get(refs) + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + refs = [ engine.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, serialized_tensors, config_dict) for engine in self.rollout_engines diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index a32e27a11..ca81db0d4 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -530,7 +530,8 @@ def _compute_server_args( if args.lora_rank > 0 or args.lora_adapter_path is not None: kwargs["enable_lora"] = True kwargs["max_lora_rank"] = args.lora_rank - kwargs["max_loras_per_batch"] = 1 + # kwargs["max_loaded_loras"] = 1 + # kwargs["max_loras_per_batch"] = 1 kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] From b56d80dea4ca22dfa0da717c01103e7cafbd954b Mon Sep 17 00:00:00 2001 From: Guanxing Lu <747398423@qq.com> Date: Mon, 12 Jan 2026 17:44:48 +0000 Subject: [PATCH 6/8] Remove save_lora_to_disk and delete_lora_from_disk Co-authored-by: PopSoda2002 --- miles/backends/fsdp_utils/lora_utils.py | 29 ---------------- .../fsdp_utils/update_weight_utils.py | 34 +++++++++---------- miles/backends/sglang_utils/sglang_engine.py | 2 -- 3 files changed, 17 insertions(+), 48 deletions(-) diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py index bb90d0396..5f1e5aed7 100644 --- a/miles/backends/fsdp_utils/lora_utils.py +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -1,9 +1,5 @@ import logging -import os -import shutil -from pathlib import Path -import torch.distributed as dist import torch.nn as nn try: @@ -46,31 +42,6 @@ def is_lora_model(module: nn.Module) -> bool: return hasattr(unwrapped, "peft_config") -def save_lora_to_disk(module: nn.Module, save_dir: str, lora_weights: dict) -> str: - """Save LoRA adapter to disk with file lock mechanism.""" - lora_state_dict = lora_weights - - if dist.get_rank() == 0: - save_path = Path(save_dir) - save_path.mkdir(parents=True, exist_ok=True) - - 
module.save_pretrained(str(save_path), state_dict=lora_state_dict) - - # TODO: check if file lock is needed or better way to do it - os.sync() - - logger.info(f"Saved LoRA adapter to {save_path}") - return save_dir - - -def delete_lora_from_disk(save_dir: str) -> None: - """Delete LoRA adapter files from disk.""" - save_path = Path(save_dir) - if save_path.exists(): - shutil.rmtree(save_path) - logger.info(f"Deleted LoRA adapter from {save_path}") - - def get_lora_config(module: nn.Module) -> dict[str, any]: """Extract LoRA config from PEFT model.""" peft_config = module.peft_config["default"] diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 380c36749..7547f55a5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,9 +1,11 @@ import abc import logging import os +import shutil import socket from argparse import Namespace from collections.abc import Sequence +from pathlib import Path import ray import torch @@ -25,14 +27,7 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] -from .lora_utils import ( - LORA_ADAPTER_NAME, - LORA_SUBDIR, - delete_lora_from_disk, - get_lora_config, - is_lora_model, - save_lora_to_disk, -) +from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, get_lora_config, is_lora_model logger = logging.getLogger(__name__) @@ -64,8 +59,8 @@ def update_weights(self) -> None: for name, param in self.model.state_dict().items(): # Skip FSDP internal parameters - if "_flat_param" in name: - continue + # if "_flat_param" in name: + # continue # Extract LoRA weights if is_lora and "lora_" in name: @@ -76,7 +71,8 @@ def update_weights(self) -> None: async_op=True, ).to_local() param = param.wait() if hasattr(param, "wait") else param - key = name.replace(".default.weight", ".weight") + key = name.replace(".default.weight", ".weight") if self.args.lora_sync_from_tensor else name + lora_weights[key] = param continue @@ -116,12 +112,19 @@ def _update_lora_via_file(self, lora_weights: dict) -> None: """Push LoRA weights to rollout engines using disk files.""" self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) if dist.get_rank() == 0: - if os.path.exists(self._lora_save_dir): - delete_lora_from_disk(self._lora_save_dir) + save_path = Path(self._lora_save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") dist.barrier() - save_lora_to_disk(self.model, self._lora_save_dir, lora_weights) + if dist.get_rank() == 0: + save_path = Path(self._lora_save_dir) + save_path.mkdir(parents=True, exist_ok=True) + self.model.save_pretrained(str(save_path), state_dict=lora_weights) + os.sync() + logger.info(f"Saved LoRA adapter to {save_path}") dist.barrier() @@ -155,9 +158,6 @@ def _update_lora_via_tensor(self, lora_weights: dict) -> None: serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True) if self._lora_loaded: - refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - ray.get(refs) - refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] ray.get(refs) diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index ca81db0d4..f7ff38268 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -530,8 +530,6 @@ def 
_compute_server_args(
     if args.lora_rank > 0 or args.lora_adapter_path is not None:
         kwargs["enable_lora"] = True
         kwargs["max_lora_rank"] = args.lora_rank
-        # kwargs["max_loaded_loras"] = 1
-        # kwargs["max_loras_per_batch"] = 1
         kwargs["lora_target_modules"] = args.target_modules
 
     external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]
 

From 68971952587f7d816094717baed353d85c9d3221 Mon Sep 17 00:00:00 2001
From: Guanxing Lu <747398423@qq.com>
Date: Wed, 14 Jan 2026 15:49:13 +0000
Subject: [PATCH 7/8] Support debug_rollout_only and debug_train_only

---
 miles/rollout/sglang_rollout.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py
index 2ae0ec870..6570c432f 100644
--- a/miles/rollout/sglang_rollout.py
+++ b/miles/rollout/sglang_rollout.py
@@ -138,7 +138,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
     }
 
     # Use LoRA adapter when LoRA is enabled
-    if args.lora_rank > 0 or args.lora_adapter_path is not None:
+    if args.lora_adapter_path is not None or (args.lora_rank > 0 and not args.debug_rollout_only):
         payload["lora_path"] = LORA_ADAPTER_NAME
 
     if args.use_rollout_routing_replay:

From 3b34a17b2fe79a3e324cc9859f88b5367d2d18b7 Mon Sep 17 00:00:00 2001
From: Guanxing Lu <747398423@qq.com>
Date: Tue, 20 Jan 2026 02:53:13 +0000
Subject: [PATCH 8/8] Only import peft if needed

Co-authored-by: PopSoda2002 <zhouhp.me@gmail.com>
---
 miles/backends/fsdp_utils/lora_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py
index 5f1e5aed7..e8d910639 100644
--- a/miles/backends/fsdp_utils/lora_utils.py
+++ b/miles/backends/fsdp_utils/lora_utils.py
@@ -2,11 +2,6 @@
 
 import torch.nn as nn
 
-try:
-    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
-except ImportError as err:
-    raise ImportError("peft library required for LoRA. Install with: pip install peft") from err
-
 logger = logging.getLogger(__name__)
 
 LORA_ADAPTER_NAME = "miles_lora"
@@ -14,6 +9,11 @@
 
 
 def apply_lora_to_model(model: nn.Module, args) -> nn.Module:
+    try:
+        from peft import LoraConfig, PeftModel, TaskType, get_peft_model
+    except ImportError as err:
+        raise ImportError("peft library required for LoRA. Install with: pip install peft") from err
+
     if args.lora_adapter_path:
         logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}")
         model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True)
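
For reference, a minimal sketch of the tensor-based LoRA sync path that patches 4-6 converge on, assuming a PEFT-wrapped FSDP model and rollout-engine actor handles exposing the methods used above (load_lora_adapter_from_tensors, unload_lora_adapter, flush_cache). The function name sync_lora_via_tensor and the model/engines/config_dict arguments are illustrative only and not part of the patches:

# Sketch only: gather LoRA weights from the FSDP state dict and push them to the
# rollout engines as tensors, mirroring _update_lora_via_tensor above.
import ray
import torch.distributed as dist
# DTensor/Replicate live under torch.distributed._tensor on older torch releases.
from torch.distributed.tensor import DTensor, Replicate
from sglang.srt.utils import MultiprocessingSerializer

LORA_ADAPTER_NAME = "miles_lora"  # adapter name used throughout the patches


def sync_lora_via_tensor(model, engines, config_dict, lora_loaded=False):
    # Collect only LoRA parameters; sharded DTensors are gathered to full replicas.
    lora_weights = {}
    for name, param in model.state_dict().items():
        if "lora_" not in name:
            continue
        param = param.cuda()
        if isinstance(param, DTensor):
            param = param.redistribute(
                placements=[Replicate()] * param.device_mesh.ndim
            ).to_local()
        # PEFT stores adapters as "...lora_A.default.weight"; drop the adapter name
        # so the rollout engine sees plain "...lora_A.weight" keys.
        lora_weights[name.replace(".default.weight", ".weight")] = param

    dist.barrier()
    if dist.get_rank() == 0:
        payload = MultiprocessingSerializer.serialize(lora_weights, output_str=True)
        if lora_loaded:
            # Drop the previous adapter before loading the refreshed weights.
            ray.get([e.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for e in engines])
        ray.get([
            e.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, payload, config_dict)
            for e in engines
        ])
        # Invalidate cached prefill results computed with the old adapter.
        ray.get([e.flush_cache.remote() for e in engines])
    dist.barrier()
    return True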