diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index d34c823aa..7c796b0ce 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -81,6 +81,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
@@ -148,6 +150,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
@@ -215,6 +219,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
      - name: Post-test cleanup
@@ -282,6 +288,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
@@ -349,6 +357,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
@@ -416,6 +426,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
@@ -449,7 +461,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}]
+        info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}]
     defaults:
       run:
         working-directory: ${{ github.workspace }}
@@ -483,6 +495,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2
index 37b6fa446..e42fba327 100644
--- a/.github/workflows/pr-test.yml.j2
+++ b/.github/workflows/pr-test.yml.j2
@@ -71,6 +71,8 @@
             {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8},
             {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2},
             {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2},
+            {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'},
+            {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'},
         ],
     },
 } %>
@@ -153,6 +155,8 @@ jobs:
 
       - name: Execute
         shell: bash
+        env:
+          ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }}
         run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
 
       - name: Post-test cleanup
diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py
index 536fdf970..7e39b0ff3 100644
--- a/miles/backends/fsdp_utils/actor.py
+++ b/miles/backends/fsdp_utils/actor.py
@@ -31,6 +31,7 @@
 )
 from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function
 from . import checkpoint
+from .lora_utils import apply_lora_to_model, is_lora_model
 from .lr_scheduler import get_lr_scheduler
 from .parallel import create_fsdp_parallel_state
 from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor
@@ -99,6 +100,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
             attn_implementation=self.args.attn_implementation,
         )
 
+        if self.args.lora_rank > 0 or self.args.lora_adapter_path:
+            model = apply_lora_to_model(model, self.args)
+
         model.train()
 
         full_state = model.state_dict()
@@ -112,11 +116,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
         self.model = model
 
         if args.gradient_checkpointing:
-            self.model.gradient_checkpointing_enable()
+            # Avoid "does not require grad" error
+            gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
 
         if args.optimizer == "adam":
+            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
             self.optimizer = torch.optim.AdamW(
-                self.model.parameters(),
+                trainable_params,
                 lr=args.lr,
                 betas=(args.adam_beta1, args.adam_beta2),
                 eps=args.adam_eps,
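Note: after `apply_lora_to_model`, only the injected adapter matrices require gradients, which is why the AdamW optimizer above is built from `requires_grad` parameters only and why reentrant gradient checkpointing has to be disabled. A quick sanity check one could run right after LoRA is applied (a sketch, not part of the patch):

```python
# Sketch: verify that only LoRA tensors stay trainable after apply_lora_to_model(model, args).
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
assert all("lora_" in name for name, p in model.named_parameters() if p.requires_grad)
```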
diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py
index 6daf7f982..994b01a28 100644
--- a/miles/backends/fsdp_utils/checkpoint.py
+++ b/miles/backends/fsdp_utils/checkpoint.py
@@ -12,21 +12,34 @@
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
 
+from miles.backends.fsdp_utils.lora_utils import is_lora_model
+
 logger = logging.getLogger(__name__)
 
 
 class ModelState(Stateful):
     """Wrapper for model state only."""
 
-    def __init__(self, model):
+    def __init__(self, model, lora_only: bool = False):
         self.model = model
+        self.lora_only = lora_only
+        self._key = "adapter" if lora_only else "model"
 
     def state_dict(self):
         model_state_dict, _ = get_state_dict(self.model, optimizers=[])
-        return {"model": model_state_dict}
+        if self.lora_only:
+            model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
+        return {self._key: model_state_dict}
 
     def load_state_dict(self, state_dict):
-        set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None)
+        data = state_dict[self._key]
+
+        if self.lora_only:
+            full_state_dict, _ = get_state_dict(self.model, optimizers=[])
+            full_state_dict.update(data)
+            set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
+        else:
+            set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)
 
 
 class OptimizerState(Stateful):
@@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None:
     model_dir = checkpoint_dir / "model"
     optimizer_dir = checkpoint_dir / "optimizer"
     lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
+    lora_dir = checkpoint_dir / "adapter"
+
+    lora_only = lora_dir.exists() and is_lora_model(actor.model)
+    model_dir = lora_dir if lora_only else model_dir
 
     if not model_dir.exists():
-        logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
+        logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.")
         return None
 
-    # Load model weights (always)
-    model_state = ModelState(actor.model)
+    model_state = ModelState(actor.model, lora_only=lora_only)
     state_dict = {"model_state": model_state}
-
     try:
         dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir))
-        logger.info(f"[FSDP] Loaded model from {model_dir}")
+        logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}")
     except Exception as e:
-        logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}")
+        logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}")
         return None
 
     # Load optimizer state (optional)
@@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None:
     dist.barrier()
 
     # Save model weights
-    model_state = ModelState(actor.model)
+    lora_only = is_lora_model(actor.model)
+    if lora_only:
+        save_dir = checkpoint_dir / "adapter"
+        if dist.get_rank() == 0:
+            save_dir.mkdir(parents=True, exist_ok=True)
+        dist.barrier()
+    else:
+        save_dir = model_dir
+
+    model_state = ModelState(actor.model, lora_only=lora_only)
     state_dict = {"model_state": model_state}
-    dcp.save(state_dict, checkpoint_id=str(model_dir))
+    dcp.save(state_dict, checkpoint_id=str(save_dir))
+    logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}")
 
     # Save optimizer state (skip if --no-save-optim is set)
     save_optimizer_state = not getattr(actor.args, "no_save_optim", False)
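Note: the `lora_only` path persists just the adapter tensors under `adapter/` and, on load, overlays them onto a freshly gathered full state dict before calling `set_state_dict`. The filter relies on PEFT's parameter naming, e.g. (illustrative keys, not part of the patch):

```python
# Sketch of the adapter-only filtering performed by ModelState.state_dict() above.
full_state = {
    "model.layers.0.self_attn.q_proj.base_layer.weight": "frozen base weight",
    "model.layers.0.self_attn.q_proj.lora_A.default.weight": "trainable adapter weight",
}
adapter_state = {k: v for k, v in full_state.items() if "lora_" in k}
assert list(adapter_state) == ["model.layers.0.self_attn.q_proj.lora_A.default.weight"]
```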
Install with: pip install peft") from err + + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return hasattr(unwrapped, "peft_config") + + +def get_lora_config(module: nn.Module) -> dict[str, any]: + """Extract LoRA config from PEFT model.""" + peft_config = module.peft_config["default"] + config_dict = { + "peft_type": "LORA", + "r": peft_config.r, + "lora_alpha": peft_config.lora_alpha, + "target_modules": list(peft_config.target_modules), + "bias": peft_config.bias, + } + return config_dict diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index d0f2360ab..7547f55a5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,8 +1,11 @@ import abc import logging +import os +import shutil import socket from argparse import Namespace from collections.abc import Sequence +from pathlib import Path import ray import torch @@ -19,12 +22,12 @@ from miles.utils.distributed_utils import init_process_group - try: from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, get_lora_config, is_lora_model logger = logging.getLogger(__name__) @@ -34,6 +37,8 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model self.weight_version = 0 + self._base_synced = False + self._lora_loaded = False @abc.abstractmethod def connect_rollout_engines( @@ -45,31 +50,132 @@ def connect_rollout_engines( def update_weights(self) -> None: self.weight_version += 1 + bucket = [] bucket_size = 0 + lora_weights = {} + is_lora = is_lora_model(self.model) + should_sync_base = not (is_lora and self._base_synced and "weight" not in self.args.offload_rollout_level) + for name, param in self.model.state_dict().items(): - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: + # Skip FSDP internal parameters + # if "_flat_param" in name: + # continue + + # Extract LoRA weights + if is_lora and "lora_" in name: + param = param.cuda() + if isinstance(param, DTensor): + param = param.redistribute( + placements=[Replicate()] * 
diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py
index d0f2360ab..7547f55a5 100644
--- a/miles/backends/fsdp_utils/update_weight_utils.py
+++ b/miles/backends/fsdp_utils/update_weight_utils.py
@@ -1,8 +1,11 @@
 import abc
 import logging
+import os
+import shutil
 import socket
 from argparse import Namespace
 from collections.abc import Sequence
+from pathlib import Path
 
 import ray
 import torch
@@ -19,12 +22,12 @@
 
 from miles.utils.distributed_utils import init_process_group
 
-
 try:
     from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket  # type: ignore[import]
 except ImportError:
     from sglang.srt.model_executor.model_runner import FlattenedTensorBucket  # type: ignore[import]
 
+from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, get_lora_config, is_lora_model
 
 logger = logging.getLogger(__name__)
@@ -34,6 +37,8 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
         self.args = args
         self.model = model
         self.weight_version = 0
+        self._base_synced = False
+        self._lora_loaded = False
 
     @abc.abstractmethod
     def connect_rollout_engines(
@@ -45,31 +50,132 @@
 
     def update_weights(self) -> None:
         self.weight_version += 1
+
         bucket = []
         bucket_size = 0
+        lora_weights = {}
+        is_lora = is_lora_model(self.model)
+        should_sync_base = not (is_lora and self._base_synced and "weight" not in self.args.offload_rollout_level)
+
         for name, param in self.model.state_dict().items():
-            param_size = param.numel() * param.element_size()
-            if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
-                self.wait_and_update_bucket_weights(bucket)
-                del bucket
-                bucket = []
-                bucket_size = 0
-
-            param = param.cuda()
-            if isinstance(param, DTensor):
-                # async version of param.full_tensor
-                param = param.redistribute(
-                    placements=[Replicate()] * param.device_mesh.ndim,
-                    async_op=True,
-                ).to_local()
-            bucket.append((name, param))
-            bucket_size += param_size
-
-        if bucket:
+            # Skip FSDP internal parameters
+            # if "_flat_param" in name:
+            #     continue
+
+            # Extract LoRA weights
+            if is_lora and "lora_" in name:
+                param = param.cuda()
+                if isinstance(param, DTensor):
+                    param = param.redistribute(
+                        placements=[Replicate()] * param.device_mesh.ndim,
+                        async_op=True,
+                    ).to_local()
+                param = param.wait() if hasattr(param, "wait") else param
+                key = name.replace(".default.weight", ".weight") if self.args.lora_sync_from_tensor else name
+
+                lora_weights[key] = param
+                continue
+
+            # Process base model weights
+            if should_sync_base:
+                name = name.replace("base_model.model.", "").replace(".base_layer", "")
+                param_size = param.numel() * param.element_size()
+                if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
+                    self.wait_and_update_bucket_weights(bucket)
+                    del bucket
+                    bucket = []
+                    bucket_size = 0
+
+                param = param.cuda()
+                if isinstance(param, DTensor):
+                    # Async version of param.full_tensor
+                    param = param.redistribute(
+                        placements=[Replicate()] * param.device_mesh.ndim,
+                        async_op=True,
+                    ).to_local()
+                bucket.append((name, param))
+                bucket_size += param_size
+
+        if should_sync_base and bucket:
             self.wait_and_update_bucket_weights(bucket)
             del bucket
-            bucket = []
-            bucket_size = 0
+            self._base_synced = True
+
+        # Update LoRA weights if needed
+        if is_lora:
+            if self.args.lora_sync_from_tensor:
+                self._update_lora_via_tensor(lora_weights)
+            else:
+                self._update_lora_via_file(lora_weights)
+
+    def _update_lora_via_file(self, lora_weights: dict) -> None:
+        """Push LoRA weights to rollout engines using disk files."""
+        self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR)
+        if dist.get_rank() == 0:
+            save_path = Path(self._lora_save_dir)
+            if save_path.exists():
+                shutil.rmtree(save_path)
+                logger.info(f"Deleted LoRA adapter from {save_path}")
+
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            save_path = Path(self._lora_save_dir)
+            save_path.mkdir(parents=True, exist_ok=True)
+            self.model.save_pretrained(str(save_path), state_dict=lora_weights)
+            os.sync()
+            logger.info(f"Saved LoRA adapter to {save_path}")
+
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            if self._lora_loaded:
+                refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines]
+                ray.get(refs)
+
+                refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+                ray.get(refs)
+
+            refs = [
+                engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir)
+                for engine in self.rollout_engines
+            ]
+            ray.get(refs)
+
+            refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+            ray.get(refs)
+
+            self._lora_loaded = True
+
+        dist.barrier()
+
+    def _update_lora_via_tensor(self, lora_weights: dict) -> None:
+        """Push LoRA weights to rollout engines using tensors."""
+        config_dict = get_lora_config(self.model)
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True)
+
+            if self._lora_loaded:
+                refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines]
+                ray.get(refs)
+
+                refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+                ray.get(refs)
+
+            refs = [
+                engine.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, serialized_tensors, config_dict)
+                for engine in self.rollout_engines
+            ]
+            ray.get(refs)
+
+            refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+            ray.get(refs)
+
+            self._lora_loaded = True
+
+        dist.barrier()
 
     def wait_and_update_bucket_weights(self, bucket):
         bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket]
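Note: `update_weights` now syncs the frozen base weights at most once. The decision mirrors the `should_sync_base` expression above: once the base has been pushed, it is skipped as long as the rollout engines keep their weights resident, i.e. `weight` is not in `--offload-rollout-level` (a minimal restatement, not part of the patch):

```python
# Minimal restatement of the base-weight sync decision in update_weights().
def needs_base_sync(is_lora: bool, base_synced: bool, offload_rollout_level: list[str]) -> bool:
    return not (is_lora and base_synced and "weight" not in offload_rollout_level)

assert needs_base_sync(False, False, ["kv_cache"])           # full fine-tuning: always sync
assert needs_base_sync(True, True, ["kv_cache", "weight"])   # weights offloaded: re-sync base
assert not needs_base_sync(True, True, ["kv_cache"])         # LoRA, weights resident: adapter only
```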
diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py
index f736cf97a..57051081e 100644
--- a/miles/backends/sglang_utils/sglang_engine.py
+++ b/miles/backends/sglang_utils/sglang_engine.py
@@ -333,9 +333,15 @@ def get_weight_version(self):
         response.raise_for_status()
         return response.json()["weight_version"]
 
-    def release_memory_occupation(self):
+    def release_memory_occupation(self, tags: list[str] = None):
+        """
+        Available tags for multi-stage release: weights, kv_cache, cuda_graph
+        """
         self.flush_cache()
-        return self._make_request("release_memory_occupation")
+        return self._make_request(
+            "release_memory_occupation",
+            {"tags": tags},
+        )
 
     def resume_memory_occupation(self, tags: list[str] = None):
         """
@@ -391,6 +397,24 @@ def update_weights_from_distributed(
             payload,
         )
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        return self._make_request(
+            "load_lora_adapter",
+            {"lora_name": lora_name, "lora_path": lora_path},
+        )
+
+    def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict):
+        return self._make_request(
+            "load_lora_adapter_from_tensors",
+            {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict},
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        return self._make_request(
+            "unload_lora_adapter",
+            {"lora_name": lora_name},
+        )
+
     def pause_generation(self):
         response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={})
         response.raise_for_status()
@@ -522,6 +546,11 @@ def _compute_server_args(
         kwargs["enable_return_routed_experts"] = True
     if args.fp16:
         kwargs["dtype"] = "float16"
+    if args.lora_rank > 0 or args.lora_adapter_path is not None:
+        kwargs["enable_lora"] = True
+        kwargs["max_lora_rank"] = args.lora_rank
+        kwargs["lora_target_modules"] = args.target_modules
+
     external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]
 
     unused_keys = set(kwargs.keys())
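Note: the three new endpoints give the trainer a complete adapter refresh cycle; `_update_lora_via_file` above drives them in this order on rank 0 (a condensed sketch of that call sequence, variable names are placeholders):

```python
# Condensed sketch of the adapter refresh driven by _update_lora_via_file():
# drop the stale adapter (if one is loaded), flush the cache, load the new one, flush again.
if lora_loaded:
    ray.get([e.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for e in rollout_engines])
    ray.get([e.flush_cache.remote() for e in rollout_engines])
ray.get([e.load_lora_adapter.remote(LORA_ADAPTER_NAME, lora_save_dir) for e in rollout_engines])
ray.get([e.flush_cache.remote() for e in rollout_engines])
```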
diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py
index eb232b161..f81c7f8ef 100644
--- a/miles/ray/placement_group.py
+++ b/miles/ray/placement_group.py
@@ -184,6 +184,7 @@ def create_rollout_manager(args, pg):
         ray.get(rollout_manager.check_weights.remote(action="reset_tensors"))
 
     if args.offload_rollout:
+        # TODO: possible future optimization: offload model weights to CPU to make more space for training.
         ray.get(rollout_manager.offload.remote())
 
     return rollout_manager, num_rollout_per_epoch
diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py
index 79c6649be..099e825b0 100644
--- a/miles/ray/rollout.py
+++ b/miles/ray/rollout.py
@@ -155,10 +155,14 @@ def save(self, rollout_id):
     def load(self, rollout_id=None):
         self.data_source.load(rollout_id)
 
-    def offload(self):
+    def offload(self, tags: list[str] | None = None):
         self.health_monitoring_pause()
         return ray.get(
-            [engine.release_memory_occupation.remote() for engine in self.rollout_engines if engine is not None]
+            [
+                engine.release_memory_occupation.remote(tags=tags)
+                for engine in self.rollout_engines
+                if engine is not None
+            ]
         )
 
     def onload(self, tags: list[str] | None = None):
diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py
index 91918340a..268888e4a 100644
--- a/miles/rollout/sglang_rollout.py
+++ b/miles/rollout/sglang_rollout.py
@@ -13,6 +13,7 @@
 from packaging.version import parse
 from tqdm import tqdm
 
+from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME
 from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
 from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter
 from miles.utils.async_utils import run
@@ -136,6 +137,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
         "return_logprob": True,
     }
 
+    # Use LoRA adapter when LoRA is enabled
+    if args.lora_adapter_path is not None or (args.lora_rank > 0 and not args.debug_rollout_only):
+        payload["lora_path"] = LORA_ADAPTER_NAME
+
     if args.use_rollout_routing_replay:
         payload["return_routed_experts"] = True
" + "Example: 'q_proj,k_proj,v_proj' (default: None)" + ), + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Comma-separated list of modules to exclude from LoRA adaptation (default: None)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to load pre-trained LoRA adapter weights (default: None)", + ) + parser.add_argument( + "--lora-sync-from-tensor", + action="store_true", + default=False, + help="Use tensor-based LoRA weight synchronization instead of file-based (default: False)", + ) + return parser def add_router_arguments(parser): @@ -1493,6 +1545,13 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def miles_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + # Check if LoRA is enabled with Megatron backend (not yet implemented) + if args.lora_rank > 0 and args.train_backend == "megatron": + raise NotImplementedError( + "LoRA is not yet implemented for Megatron backend. " + "Please use FSDP backend (--train-backend fsdp) or disable LoRA (--lora-rank 0)." + ) + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1528,6 +1587,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.lora_rank > 0: + assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 2c20195fc..3840f294d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf +peft pillow pylatexenc pyyaml diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..f12837d88 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index dcdbd5834..e33698af3 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -12,7 +12,7 @@ def prepare(): 
U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f..f52c80afe 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -11,7 +11,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ce..68642bb36 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,18 +1,36 @@ +import importlib.util import os import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") def prepare(): + if ENABLE_LORA: + if importlib.util.find_spec("peft") is None: + U.exec_command("pip install peft") + U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + # "--lora-sync-from-tensor " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -51,7 +69,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -74,10 +92,17 @@ def execute(): "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step ) - misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + misc_args = ( + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 2 " + "--colocate " + "--offload-rollout-level kv_cache weight " + "--train-backend fsdp " + ) train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4c..6b07b7e91 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,21 +1,37 @@ +import importlib.util import os import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" - - +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") def prepare(): + if ENABLE_LORA: + if importlib.util.find_spec("peft") is None: + U.exec_command("pip install peft") + U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir 
/root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + # "--lora-sync-from-tensor " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -55,7 +71,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -68,6 +84,7 @@ def execute(): "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--offload-rollout-level kv_cache weight " "--train-backend fsdp " ) @@ -80,6 +97,7 @@ def execute(): train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/train.py b/train.py index 745dcbed6..e780811f5 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import ray +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from miles.utils.arguments import parse_args @@ -61,6 +62,10 @@ def save(rollout_id): if args.rollout_global_dataset: ray.get(rollout_manager.save.remote(rollout_id)) + def onload_rollout(): + if args.offload_rollout and "weight" in args.offload_rollout_level: + ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # train loop. # note that for async training, one can change the position of the sync operation(ray.get). for rollout_id in range(args.start_rollout_id, args.num_rollout): @@ -70,7 +75,12 @@ def save(rollout_id): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)