From 4967316544086efc6012cc35a0e289bb6130747a Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 7 Jan 2026 20:15:51 +0000 Subject: [PATCH 01/12] add example --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh new file mode 100644 index 000000000..7f5ffa751 --- /dev/null +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-0.5B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ + --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ + # Uncomment to save checkpoints (required for LoRA) + --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + --save-interval 5 +) + +LORA_ARGS=( + --lora-rank 16 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + # Target modules - use Megatron naming or HF naming + # Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + # HF: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj + # --target-modules "all-linear" + --target-modules "q_proj,k_proj,v_proj,o_proj" + # --target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" + # --lora-sync-from-tensor # Use tensor-based sync (more efficient) + # Uncomment to share base model between actor and ref (saves memory) + --share-ref-base-model +) + +ROLLOUT_ARGS=( + --prompt-data /root/gsm8k/train.parquet + --input-key messages + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + # --num-rollout 100 + --num-rollout 10 # onyl train 10 stesp + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 1024 + --rollout-temperature 1 + + --global-batch-size 256 +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 # Higher LR often works better for LoRA + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +# WANDB_ARGS=( +# --use-wandb +# --wandb-host https://wandb.ai/ +# --wandb-team glm-zero +# --wandb-project miles-dev +# --wandb-group qwen2.5-0.5B-gsm8k-deterministic +# ) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + + --sglang-enable-deterministic-inference + --sglang-attention-backend flashinfer + + --deterministic-mode +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + 
--attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +# launch the master node of ray in container +ray start --head --node-ip-address 127.0.0.1 --num-gpus 8 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + --calculate-per-token-loss \ + --use-miles-router \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${LORA_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} From f593e3311f2cb33c97df76d38184dbce257e5a7e Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 12 Jan 2026 00:43:41 +0000 Subject: [PATCH 02/12] fix megatron training problem --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 58 ++- miles/backends/megatron_utils/actor.py | 97 +++- miles/backends/megatron_utils/arguments.py | 33 ++ miles/backends/megatron_utils/checkpoint.py | 49 +- miles/backends/megatron_utils/lora_utils.py | 446 ++++++++++++++++++ .../megatron_utils/megatron_to_hf/__init__.py | 48 ++ miles/backends/megatron_utils/model.py | 69 ++- .../backends/megatron_utils/model_provider.py | 43 ++ .../update_weight/hf_weight_iterator_base.py | 11 +- .../hf_weight_iterator_bridge.py | 129 ++++- .../hf_weight_iterator_direct.py | 2 +- .../update_weight_from_tensor.py | 27 +- miles/backends/sglang_utils/sglang_engine.py | 91 ++++ miles/rollout/sglang_rollout.py | 12 + miles/utils/arguments.py | 103 ++++ 15 files changed, 1169 insertions(+), 49 deletions(-) create mode 100644 miles/backends/megatron_utils/lora_utils.py diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 7f5ffa751..1cc67bbcc 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -20,26 +20,47 @@ source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-0.5B.sh" CKPT_ARGS=( --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ - --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ + # --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ # Uncomment to save checkpoints (required for LoRA) --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ --save-interval 5 ) + +############################## +###########lora############### +############################## LORA_ARGS=( --lora-rank 16 # LoRA rank (typical values: 8, 16, 32, 64) --lora-alpha 32 # LoRA alpha (usually 2x rank) --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) # Target modules - use Megatron naming or HF naming # Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 - # HF: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj - # --target-modules "all-linear" - --target-modules "q_proj,k_proj,v_proj,o_proj" - # --target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" + --target-modules "all-linear" + # Need this PR: Update LoRA Weights via Tensor sgl-project/sglang#16226 # --lora-sync-from-tensor # Use tensor-based sync (more efficient) - # Uncomment to share base model between actor and 
ref (saves memory) - --share-ref-base-model + ## Uncomment to share base model between actor and ref (saves memory) + # --share-ref-base-model + ############################## + ############################## + # # Debug + # --debug-rollout-only + --debug-train-only + --load-debug-rollout-data /root/debug_data/rollout_data.pt + # # --save-debug-rollout-data /root/debug_data/rollout_data.pt + ############################## + ############################## + # --no-use-distributed-optimizer # if open it will has error: /home/radixark/yushengsu/miles-pr/miles/miles/utils/arguments.py: + #def set_default_megatron_args(args): (error) # optimizer cannot distributed to other gpus (enable) + + --megatron-to-hf-mode bridge + # Disable gradient accumulation fusion for LoRA training + + # --no-gradient-accumulation-fusion #Root cause: When training with LoRA, the base model’s parameters are frozen (requires_grad=False). However, Megatron-LM’s tensor-parallel layers use gradient-accumulation fusion during the backward pass, and that fusion path checks weight.main_grad.dtype. For frozen parameters, main_grad is never allocated (it remains None), which triggers the error. (enable) ) +############################## +############################## +############################## ROLLOUT_ARGS=( --prompt-data /root/gsm8k/train.parquet @@ -68,7 +89,7 @@ EVAL_ARGS=( PERF_ARGS=( --tensor-model-parallel-size 1 - --sequence-parallel + --sequence-parallel #becasue of lora training error: RuntimeError: Cannot access the main gradient of a frozen parameter. main_grad is None. (enable) --pipeline-model-parallel-size 1 --context-parallel-size 1 --expert-model-parallel-size 1 @@ -80,7 +101,7 @@ PERF_ARGS=( GRPO_ARGS=( --advantage-estimator grpo - --use-kl-loss + # --use-kl-loss # if use kl loss, should use --ref-load --kl-loss-coef 0.00 --kl-loss-type low_var_kl --kl-coef 0.00 @@ -128,8 +149,18 @@ MISC_ARGS=( --attention-backend flash ) + +############################## +###########lora############### +############################## +export GPUS_PER_NODE=1 +############################## +############################## +############################## + # launch the master node of ray in container -ray start --head --node-ip-address 127.0.0.1 --num-gpus 8 --disable-usage-stats +ray start --head --node-ip-address 127.0.0.1 --num-gpus $GPUS_PER_NODE --disable-usage-stats +# ray start --head --node-ip-address 127.0.0.1 --num-gpus 1 --disable-usage-stats ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json='{ @@ -143,18 +174,19 @@ ray job submit --address="http://127.0.0.1:8265" \ }' \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ --colocate \ --calculate-per-token-loss \ --use-miles-router \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ ${LORA_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ ${OPTIMIZER_ARGS[@]} \ ${GRPO_ARGS[@]} \ ${WANDB_ARGS[@]} \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} + ${MISC_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 7cc7f2619..1961b4e0e 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -37,7 +37,19 @@ from .update_weight.common import named_params_and_buffers from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed from .update_weight.update_weight_from_tensor import UpdateWeightFromTensor - 
+############################## +###########lora############### +############################## +from .lora_utils import ( + is_lora_enabled, + is_lora_model, + # apply_lora_to_megatron_model, + # get_lora_weights_and_config, + freeze_base_model, +) +############################## +############################## +############################## logging.getLogger("megatron").setLevel(logging.WARNING) logger = logging.getLogger(__name__) @@ -92,6 +104,33 @@ def init( args, role ) + ### share ref model + ############################## + ###########lora############### + ############################## + # # For LoRA with share-ref-base-model: backup base model weights BEFORE applying LoRA + # if is_lora_enabled(args) and role == "actor" and with_ref and getattr(args, 'share_ref_base_model', False): + # # Create weights_backuper early to backup base weights as "ref" before LoRA + # self.weights_backuper = TensorBackuper.create( + # source_getter=lambda: named_params_and_buffers( + # self.args, + # self.model, + # convert_to_global_name=args.megatron_to_hf_mode == "raw", + # translate_gpu_to_cpu=not self.args.enable_weights_backuper, + # ), + # single_tag=None if args.enable_weights_backuper else "actor", + # ) + # self.weights_backuper.backup("ref") # Backup base weights as ref BEFORE LoRA + # logger.info("Backed up base model weights as 'ref' before applying LoRA (share-ref-base-model mode)") + + # if is_lora_enabled(args) and role == "actor": + # self.model = apply_lora_to_megatron_model(self.model, args) + # freeze_base_model(self.model) + ############################## + ############################## + ############################## + + if role == "critic": if self.args.offload_train: self.sleep() @@ -108,12 +147,47 @@ def init( ), single_tag=None if args.enable_weights_backuper else "actor", ) + # Deal with actor model --> delt with in model.py + # ############################## + # ###########lora############### + # ############################## + # if is_lora_enabled(args): + # # self.weights_backuper.backup("ref") # Backup base weights as ref BEFORE LoRA (prevent load model weight again on later) + + # self.model = apply_lora_to_megatron_model(self.model, args) # model: base + lora including `requires_grad` process + # # freeze_base_model(self.model) # Set `requires_grad`: base + lora .. do not set here since self.weights_backuper.backup(...) does not process `requires_grad` + # ############################## + # ############################## + # ############################## self._active_model_tag: str | None = "actor" self.weights_backuper.backup("actor") + if with_ref: - self.load_other_checkpoint("ref", args.ref_load) + ############################## + ###########lora############### + ############################## + # self.load_other_checkpoint("ref", args.ref_load) + + # if use lora: --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ (should be also lora weight) + if is_lora_enabled(args): + raise NotImplementedError( + "LoRA with reference model is not yet fully implemented. " + "Please remove reference model settings from your training script:\n" + " 0. Might need to ensure self.load_other_checkpoint can load loar module as well.\n" + " 1. Remove '--use-kl-loss' flag, OR\n" + " 2. Set '--kl-coef 0' without '--use-kl-loss', OR\n" + " 3. Remove '--ref-load' parameter\n" + "This will disable reference model loading (with_ref=False) and allow LoRA training to proceed." 
+ ) + else: + self.load_other_checkpoint("ref", args.ref_load) + ############################## + ############################## + ############################## + + if self.args.keep_old_actor: # Load old_actor checkpoint self.load_other_checkpoint("old_actor", args.load) @@ -131,6 +205,13 @@ def init( weights_getter=lambda: self.weights_backuper.get("actor"), model_name=type(self.hf_config).__name__.lower() if self.args.model_name is None else self.args.model_name, quantization_config=getattr(self.hf_config, "quantization_config", None), + ############################## + ###########lora############### + ############################## + is_lora=is_lora_enabled(args), + ############################## + ############################## + ############################## ) # empty cache after initialization @@ -248,6 +329,18 @@ def _switch_model(self, target_tag: str) -> None: self.weights_backuper.restore(target_tag) self._active_model_tag = target_tag + ############################## + ###########lora############### + ############################## + # Restore requires_grad after weight restoration + # For LoRA training: only adapter params should be trainable, base model frozen + if is_lora_enabled(self.args): + freeze_base_model(self.model) + # Note: ref model uses forward_only (@torch.no_grad), so requires_grad doesn't matter + ############################## + ############################## + ############################## + def fill_routing_replay(self, data_iterator, num_microbatches, rollout_data): if "rollout_routed_experts" not in rollout_data: raise ValueError( diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index aea72ceb8..03ed0a2e8 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -10,7 +10,40 @@ def set_default_megatron_args(args): # always use zero optimizer + ############################## + ###########lora############### + ############################## args.use_distributed_optimizer = True + + # from miles.backends.megatron_utils.lora_utils import is_lora_enabled + # # this should be enalbe after optimize + # if is_lora_enabled(args): + # # Cannot Use distributed optimizer (ZeRO) in LoRA training. 
+ # args.use_distributed_optimizer = False + + # # === NEW: Disable features that cause issues with frozen parameters === + # # Disable gradient accumulation fusion (already have --no-gradient-accumulation-fusion) + # args.gradient_accumulation_fusion = False + + # # Disable async tensor model parallel allreduce to avoid main_grad access + # args.async_tensor_model_parallel_allreduce = False + + # # Disable overlap grad reduce (needs gradient buffers for all params) + # args.overlap_grad_reduce = False + + # # Disable sequence parallel if enabled (can cause similar issues) + # if hasattr(args, 'sequence_parallel') and args.sequence_parallel: + # import logging + # logging.getLogger(__name__).warning( + # "Disabling sequence_parallel for LoRA training (incompatible with frozen parameters)" + # ) + # args.sequence_parallel = False + # else: + # args.use_distributed_optimizer = True + ############################## + ############################## + ############################## + # TODO: maybe change this after megatron has good fp8 support args.bf16 = not args.fp16 # placeholders diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 35d712910..914bab96d 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -10,9 +10,24 @@ from miles.utils import megatron_bridge_utils +############################## +###########lora############### +############################## +from miles.backends.megatron_utils.lora_utils import is_lora_model, save_lora_checkpoint, load_lora_checkpoint +############################## +############################## +############################## + logger = logging.getLogger(__name__) -__all__ = ["save_checkpoint"] +############################## +###########lora############### +############################## +# __all__ = ["save_checkpoint"] +__all__ = ["save_checkpoint", "save_checkpoint_with_lora", "load_checkpoint"] +############################## +############################## +############################## def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt): @@ -24,6 +39,20 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con load_path ), f"{args.load=} does not exist or is an empty directory. Did you specify the wrong folder?" 
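+    # NOTE: save_checkpoint_with_lora() below writes the adapter to
+    # <args.save>/iter_<NNNNNNN>/adapter/ (adapter_model.bin + adapter_config.json, as written by
+    # save_lora_checkpoint in lora_utils.py), so --load presumably has to point at a specific
+    # iter_<NNNNNNN> directory for the adapter check below to find it.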
+ ############################## + ###########lora############### + ############################## + # Check for LoRA adapter first + lora_path = Path(load_path) / "adapter" + if lora_path.exists() and is_lora_model(ddp_model): + logger.info(f"Loading LoRA checkpoint from {lora_path}") + iteration = load_lora_checkpoint(ddp_model, args, str(lora_path)) + num_floating_point_operations_so_far = 0 + return iteration, num_floating_point_operations_so_far + ############################## + ############################## + ############################## + if _is_megatron_checkpoint(load_path): return _load_checkpoint_megatron( ddp_model=ddp_model, @@ -40,6 +69,24 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con load_path=load_path, ) +############################## +###########lora############### +############################## +def save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler): + """Extended save that handles LoRA adapters separately.""" + args = get_args() + + if is_lora_model(model): + # Save only LoRA adapter weights + save_dir = Path(args.save) / f"iter_{iteration:07d}" / "adapter" + logger.info(f"Saving LoRA checkpoint to {save_dir}") + save_lora_checkpoint(model, args, str(save_dir)) + else: + # Use standard Megatron save + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) +############################## +############################## +############################## def _is_megatron_checkpoint(path: str | Path) -> bool: return (Path(path) / "latest_checkpointed_iteration.txt").is_file() or bool( diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py new file mode 100644 index 000000000..5bf370314 --- /dev/null +++ b/miles/backends/megatron_utils/lora_utils.py @@ -0,0 +1,446 @@ +############################## +###########lora############### +############################## +# to-do(yusheng): this should be moved to utils or split into hf_weight_iterator_bridge.py + +"""LoRA utilities for Megatron backend using Megatron-Bridge PEFT integration.""" + +import logging +import os +from argparse import Namespace +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +from megatron.core import mpu + +logger = logging.getLogger(__name__) + +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def is_lora_enabled(args: Namespace) -> bool: + """Check if LoRA is enabled.""" + return args.lora_rank > 0 or args.lora_adapter_path is not None + + +# def apply_lora_to_megatron_model( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> Sequence[torch.nn.Module]: +# """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. 
+ +# This uses the Megatron-Bridge's PEFT support from: +# https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft + +# Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 + +# Args: +# model: Megatron model (DDP wrapped) +# args: Training arguments with LoRA config + +# Returns: +# LoRA-wrapped model +# """ +# # from megatron.bridge.peft import apply_lora_adapter, LoraConfig +# from megatron.bridge.peft.lora import LoRA + +# if args.lora_adapter_path: +# # TODO: Loading existing LoRA adapter needs separate implementation +# # Megatron-Bridge may have different API for loading +# # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# raise NotImplementedError("Loading existing LoRA adapter is not yet implemented") +# else: +# # Determine lora_dtype from args +# if hasattr(args, 'bf16') and args.bf16: +# lora_dtype = torch.bfloat16 +# elif hasattr(args, 'fp16') and args.fp16: +# lora_dtype = torch.float16 +# else: +# lora_dtype = None # Will use model's dtype + +# # Get exclude_modules as list +# exclude_modules = [] +# if hasattr(args, 'exclude_modules') and args.exclude_modules: +# if isinstance(args.exclude_modules, str): +# exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] +# else: +# exclude_modules = list(args.exclude_modules) + +# # Create new LoRA adapter using Megatron-Bridge LoRA dataclass +# # There are different lora_type, I just use the classic one (speed and acc might not the optimal) +# # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# lora = LoRA( +# target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] +# exclude_modules=exclude_modules, # Modules to exclude from LoRA +# dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) +# alpha=args.lora_alpha, # LoRA alpha scaling factor +# dropout=args.lora_dropout, # LoRA dropout rate +# dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' +# lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix +# lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix +# a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication +# lora_dtype=lora_dtype, # Parameter data type for LoRA weights +# ) +# logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " +# f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " +# f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") + +# # Apply LoRA to each model chunk +# # The LoRA class is callable - calling it applies the transformation +# for model_chunk in model: +# # lora(model_chunk.module, training=True) applies LoRA and freezes base model +# lora(model_chunk.module, training=True) + +# # Print trainable parameters info +# _print_trainable_parameters(model) + +# return model + + +def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: + """Print trainable parameters statistics.""" + total_params = 0 + trainable_params = 0 + trainable_param_names = [] + + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + total_params += param.numel() + if param.requires_grad: + trainable_params 
+= param.numel() + trainable_param_names.append((name, param.numel())) + + if mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0: + logger.info( + f"LoRA trainable params: {trainable_params:,} / {total_params:,} " + f"({100 * trainable_params / total_params:.2f}%)" + ) + # if trainable_param_names: + # logger.info(f"\nTrainable layers ({len(trainable_param_names)} parameters):") + # for name, num_params in trainable_param_names: + # logger.info(f" ✓ {name}: {num_params:,} params") + # else: + # logger.warning("⚠️ NO TRAINABLE PARAMETERS! LoRA may not be applied correctly.") + + +def is_lora_model(model: Sequence[torch.nn.Module]) -> bool: + """Check if model has LoRA layers applied.""" + for model_chunk in model: + if hasattr(model_chunk.module, "peft_config"): + return True + # Check for LoRA layers in parameters + for name, _ in model_chunk.named_parameters(): + if "lora_" in name: + return True + return False + + +# def get_lora_state_dict( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> dict[str, torch.Tensor]: +# """Extract LoRA weights from model. + +# Returns only the LoRA adapter weights, not the base model weights. +# """ +# from miles.backends.megatron_utils.update_weight.common import named_params_and_buffers + +# lora_state_dict = {} + +# for name, param in named_params_and_buffers(args, model, convert_to_global_name=True): +# if "lora_" in name or ".adapter." in name: +# lora_state_dict[name] = param + +# return lora_state_dict + + +# def get_lora_weights_and_config( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: +# """Extract LoRA weights and config for tensor-based sync. + +# This is used for efficient weight sync to SGLang engines. +# """ +# lora_state_dict = get_lora_state_dict(model, args) + +# # Convert Megatron names to HF-compatible names for SGLang +# hf_state_dict = {} +# for name, param in lora_state_dict.items(): +# # Convert megatron naming to HF naming +# hf_name = _convert_megatron_to_hf_lora_name(name) +# hf_state_dict[hf_name] = param + +# config_dict = { +# "peft_type": "LORA", +# "r": args.lora_rank, +# "lora_alpha": args.lora_alpha, +# "target_modules": list(args.target_modules), +# "bias": "none", +# } + +# if mpu.get_data_parallel_rank() == 0: +# logger.info(f"Extracted {len(hf_state_dict)} LoRA weight tensors for sync") + +# return hf_state_dict, config_dict + + +# def _convert_megatron_to_hf_lora_name(name: str) -> str: +# """Convert Megatron LoRA parameter name to HuggingFace format. + +# Megatron: module.module.decoder.layers.0.self_attention.linear_qkv.lora_A.weight +# HF: model.layers.0.self_attn.q_proj.lora_A.weight +# """ +# # This mapping should match your specific model architecture +# replacements = [ +# ("module.module.decoder.layers.", "model.layers."), +# (".self_attention.linear_qkv.lora_", ".self_attn.q_proj.lora_"), +# (".self_attention.linear_proj.lora_", ".self_attn.o_proj.lora_"), +# (".mlp.linear_fc1.lora_", ".mlp.gate_proj.lora_"), +# (".mlp.linear_fc2.lora_", ".mlp.down_proj.lora_"), +# ] + +# result = name +# for old, new in replacements: +# result = result.replace(old, new) + +# return result + + +# def save_lora_checkpoint( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# save_dir: str, +# ) -> str: +# """Save LoRA adapter checkpoint to disk. 
+ +# Args: +# model: Megatron model with LoRA +# args: Training arguments +# save_dir: Directory to save checkpoint + +# Returns: +# Path to saved checkpoint +# """ +# from megatron.bridge.peft import save_lora_adapter + +# save_path = Path(save_dir) +# save_path.mkdir(parents=True, exist_ok=True) + +# # Use Megatron-Bridge's save function +# if mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0: +# for model_chunk in model: +# save_lora_adapter(model_chunk.module, str(save_path)) +# os.sync() +# logger.info(f"Saved LoRA adapter to {save_path}") + +# dist.barrier() +# return str(save_path) + + +## to-do (yusheng): need to confirm usage +def save_lora_checkpoint( + model: Sequence[torch.nn.Module], + args: Namespace, + save_dir: str, +) -> str: + """Save LoRA adapter checkpoint to disk in HuggingFace PEFT format. + + Since Megatron-Bridge doesn't have a save_lora_adapter function, + we manually extract adapter weights and convert to PEFT format. + """ + import json + from pathlib import Path + from megatron.bridge.peft.lora_layers import LoRALinear, LinearAdapter, TELinearAdapter + from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + + save_path = Path(save_dir) + + # Only rank 0 saves (other ranks just return) + if not (mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0): + return str(save_path) + + save_path.mkdir(parents=True, exist_ok=True) + + lora_state_dict = {} + + for model_chunk in model: + for name, module in model_chunk.named_modules(): + linear_in = None + linear_out = None + + # LoRALinear (wraps base layer with adapter) + if isinstance(module, AdapterWrapper) and hasattr(module, 'adapter'): + adapter = module.adapter + if hasattr(adapter, 'linear_in') and hasattr(adapter, 'linear_out'): + linear_in = adapter.linear_in + linear_out = adapter.linear_out + # LinearAdapter/TELinearAdapter (extends nn.Linear with lora) + elif isinstance(module, (LinearAdapter, TELinearAdapter)): + if hasattr(module, 'linear_in') and hasattr(module, 'linear_out'): + linear_in = module.linear_in + linear_out = module.linear_out + + if linear_in is not None and linear_out is not None: + # Convert Megatron naming to HF PEFT naming + base_name = name.replace("module.module.", "base_model.model.") + base_name = base_name.replace(".decoder.layers.", ".model.layers.") + base_name = base_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + base_name = base_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") + base_name = base_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") + base_name = base_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + + lora_state_dict[f"{base_name}.lora_A.weight"] = linear_in.weight.data.cpu() + lora_state_dict[f"{base_name}.lora_B.weight"] = linear_out.weight.data.cpu() + + # Save weights + torch.save(lora_state_dict, save_path / "adapter_model.bin") + + # Save PEFT config + config = { + "peft_type": "LORA", + "r": args.lora_rank, + "lora_alpha": args.lora_alpha, + "target_modules": list(args.target_modules) if args.target_modules else ["q_proj", "o_proj", "gate_proj", "down_proj"], + "bias": "none", + "task_type": "CAUSAL_LM", + } + with open(save_path / "adapter_config.json", "w") as f: + json.dump(config, f, indent=2) + + os.sync() + logger.info(f"Saved LoRA adapter to {save_path} with {len(lora_state_dict)} tensors") + + return str(save_path) + + + + +def load_lora_checkpoint( + model: Sequence[torch.nn.Module], + args: Namespace, + load_dir: str, +) -> None: + """Load LoRA 
adapter checkpoint from disk. + + Args: + model: Megatron model + args: Training arguments + load_dir: Directory containing checkpoint + """ + from megatron.bridge.peft import load_lora_adapter + + load_path = Path(load_dir) + if not load_path.exists(): + raise FileNotFoundError(f"LoRA checkpoint not found at {load_path}") + + logger.info(f"Loading LoRA adapter from {load_path}") + + for model_chunk in model: + load_lora_adapter(model_chunk.module, str(load_path)) + + dist.barrier() + + +# ## to-do (yusheng): need to confirm usage +# def load_lora_checkpoint( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# load_dir: str, +# ) -> None: +# """Load LoRA adapter checkpoint from disk. + +# Note: This loads PEFT-format checkpoints into Megatron-Bridge LoRA layers. +# The checkpoint must be in HuggingFace PEFT format (adapter_model.bin + adapter_config.json). +# """ +# import json +# from pathlib import Path +# from megatron.bridge.peft.lora_layers import LoRALinear, LinearAdapter, TELinearAdapter +# from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + +# load_path = Path(load_dir) +# if not load_path.exists(): +# raise FileNotFoundError(f"LoRA checkpoint not found at {load_path}") + +# # Load state dict +# state_dict_path = load_path / "adapter_model.bin" +# if not state_dict_path.exists(): +# raise FileNotFoundError(f"adapter_model.bin not found in {load_path}") + +# lora_state_dict = torch.load(state_dict_path, map_location="cpu") + +# logger.info(f"Loading LoRA adapter from {load_path} with {len(lora_state_dict)} tensors") + +# # Build reverse name mapping (HF -> Megatron) +# def hf_to_megatron_name(hf_name: str) -> str: +# name = hf_name.replace("base_model.model.", "module.module.") +# name = name.replace(".model.layers.", ".decoder.layers.") +# name = name.replace(".self_attn.q_proj", ".self_attention.linear_qkv") +# name = name.replace(".self_attn.o_proj", ".self_attention.linear_proj") +# name = name.replace(".mlp.gate_proj", ".mlp.linear_fc1") +# name = name.replace(".mlp.down_proj", ".mlp.linear_fc2") +# return name + +# # Load weights into model +# for model_chunk in model: +# for name, module in model_chunk.named_modules(): +# linear_in = None +# linear_out = None + +# if isinstance(module, AdapterWrapper) and hasattr(module, 'adapter'): +# adapter = module.adapter +# if hasattr(adapter, 'linear_in') and hasattr(adapter, 'linear_out'): +# linear_in = adapter.linear_in +# linear_out = adapter.linear_out +# elif isinstance(module, (LinearAdapter, TELinearAdapter)): +# if hasattr(module, 'linear_in') and hasattr(module, 'linear_out'): +# linear_in = module.linear_in +# linear_out = module.linear_out + +# if linear_in is not None and linear_out is not None: +# # Find corresponding HF name +# base_name = name.replace("module.module.", "base_model.model.") +# base_name = base_name.replace(".decoder.layers.", ".model.layers.") +# base_name = base_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") +# base_name = base_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") +# base_name = base_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") +# base_name = base_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + +# lora_a_key = f"{base_name}.lora_A.weight" +# lora_b_key = f"{base_name}.lora_B.weight" + +# if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict: +# linear_in.weight.data.copy_(lora_state_dict[lora_a_key].to(linear_in.weight.device)) +# linear_out.weight.data.copy_(lora_state_dict[lora_b_key].to(linear_out.weight.device)) 
+ +# dist.barrier() +# logger.info(f"Successfully loaded LoRA adapter from {load_path}") + + + + +def freeze_base_model(model: Sequence[torch.nn.Module]) -> None: + """Freeze base model parameters, only keep LoRA trainable.""" + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if "lora_" not in name and "adapter" not in name: + param.requires_grad = False + + +def get_trainable_params_for_optimizer( + model: Sequence[torch.nn.Module], +) -> list[torch.nn.Parameter]: + """Get only trainable parameters for optimizer (LoRA params only).""" + trainable_params = [] + for model_chunk in model: + for param in model_chunk.parameters(): + if param.requires_grad: + trainable_params.append(param) + return trainable_params +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/megatron_utils/megatron_to_hf/__init__.py b/miles/backends/megatron_utils/megatron_to_hf/__init__.py index ba5a286a3..bc1ba073b 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/__init__.py @@ -83,3 +83,51 @@ def _convert_to_hf_core(args, model_name, name, param): else: converted_named_tensors.append((converted_name, converted_param)) return converted_named_tensors + +############################## +###########lora############### +############################## +### This might be model specific --> make it more general +def convert_lora_to_hf(args, model_name, name, param): + """ + Convert Megatron LoRA parameter to HuggingFace PEFT format. + + Megatron format: module.module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight + HF PEFT format: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight + """ + # Determine if this is lora_A (linear_in) or lora_B (linear_out) + if ".linear_in." in name or ".lora_A." in name: + lora_suffix = "lora_A.weight" + elif ".linear_out." in name or ".lora_B." 
in name: + lora_suffix = "lora_B.weight" + else: + # Fallback - return as is + return [(name, param)] + + # Convert Megatron naming to HF PEFT naming + hf_name = name + + # Remove Megatron wrapper prefixes + hf_name = hf_name.replace("module.module.", "base_model.model.") + + # Convert layer path + hf_name = hf_name.replace(".decoder.layers.", ".model.layers.") + + # Convert attention modules + hf_name = hf_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + hf_name = hf_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") + + # Convert MLP modules + hf_name = hf_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") + hf_name = hf_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + + # Replace adapter naming with lora naming + hf_name = hf_name.replace(".adapter.linear_in.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".adapter.linear_out.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".lora_A.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".lora_B.weight", f".{lora_suffix}") + + return [(hf_name, param)] +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 780370453..d3ffd6a53 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -25,7 +25,16 @@ from miles.utils import tracking_utils from miles.utils.memory_utils import clear_memory -from .checkpoint import load_checkpoint, save_checkpoint +############################## +###########lora############### +############################## +# from .checkpoint import load_checkpoint, save_checkpoint +from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora +from .lora_utils import is_lora_model +############################## +############################## +############################## + from .data import DataIterator, get_batch from .loss import loss_function from .model_provider import get_model_provider_func @@ -107,6 +116,16 @@ def setup_model_and_optimizer( model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + ############################## + ###########lora############### + ############################## + # from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model + # if is_lora_enabled(args) and role == "actor": + # model = apply_lora_to_megatron_model(model, args) + ############################## + ############################## + ############################## + # Optimizer kwargs = {} for f in dataclasses.fields(OptimizerConfig): @@ -703,16 +722,44 @@ def save( args = get_args() if should_disable_forward_pre_hook(args): disable_forward_pre_hook(model) - save_checkpoint( - iteration, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far=0, - checkpointing_context=None, - train_data_iterator=None, - preprocess_common_state_dict_fn=None, - ) + + ############################## + ###########lora############### + ############################## + # save_checkpoint( + # iteration, + # model, + # optimizer, + # opt_param_scheduler, + # num_floating_point_operations_so_far=0, + # checkpointing_context=None, + # train_data_iterator=None, + # preprocess_common_state_dict_fn=None, + # ) + + if is_lora_model(model): + save_checkpoint_with_lora( + iteration, + model, + optimizer, + opt_param_scheduler, + ) + else: + save_checkpoint( + iteration, + model, + 
optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far=0, + checkpointing_context=None, + train_data_iterator=None, + preprocess_common_state_dict_fn=None, + ) + + ############################## + ############################## + ############################## + if should_disable_forward_pre_hook(args): enable_forward_pre_hook(model) diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 7834f1101..7b6a285bf 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -79,6 +79,13 @@ def wrapped_model_provider( if args.megatron_to_hf_mode == "bridge": from megatron.bridge import AutoBridge + ############################## + ###########lora############### + ############################## + from miles.backends.megatron_utils.lora_utils import is_lora_enabled + ############################## + ############################## + ############################## bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) provider = bridge.to_megatron_provider(load_weights=False) @@ -88,6 +95,42 @@ def wrapped_model_provider( provider.expert_model_parallel_size = args.expert_model_parallel_size provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel + ############################## + ###########lora############### + ############################## + # Register LoRA pre_wrap_hook(before setting up DDP) + if is_lora_enabled(args) and role == "actor": + def lora_pre_wrap_hook(model): + """Apply LoRA to model BEFORE DDP wrapping.""" + from megatron.bridge.peft.lora import LoRA + import torch + + # Set up lora_dtype + if hasattr(args, 'bf16') and args.bf16: + lora_dtype = torch.bfloat16 + elif hasattr(args, 'fp16') and args.fp16: + lora_dtype = torch.float16 + else: + lora_dtype = None + + lora = LoRA( + target_modules=args.target_modules, + dim=args.lora_rank, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + lora_dtype=lora_dtype, + ) + + # Apply LoRA and freeze base model + transformed_model = lora(model, training=True) + lora.set_params_to_save(transformed_model) + + return transformed_model + + provider.register_pre_wrap_hook(lora_pre_wrap_hook) + ############################## + ############################## + ############################## provider.finalize() return provider.provide diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py index ef7d62e8a..3cedf28db 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py @@ -14,11 +14,20 @@ def create(args, model, **kwargs): return c(args, model, **kwargs) - def __init__(self, args, model, model_name, quantization_config): + # def __init__(self, args, model, model_name, quantization_config): + def __init__(self, args, model, model_name, quantization_config, **kwargs): self.args = args self.model = model self.model_name = model_name self.quantization_config = quantization_config + ############################## + ###########lora############### + ############################## + self.is_lora = kwargs.pop('is_lora', False) + self._base_synced = kwargs.pop('_base_synced', False) + ############################## + ############################## + ############################## @abstractmethod def get_hf_weight_chunks(self, 
megatron_local_weights): diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 7e0a4817e..f15a0491c 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -1,4 +1,5 @@ import dataclasses +import torch from miles.utils import megatron_bridge_utils from miles.utils.iter_utils import chunk_named_params_by_size @@ -7,11 +8,34 @@ from ..misc_utils import strip_param_name_prefix from .hf_weight_iterator_base import HfWeightIteratorBase +############################## +###########lora############### +############################## +def _normalize_base_weight_name(param_name: str) -> str: + """Remove the 'base_layer' suffix emitted when merge_adapter_weights=False.""" + if param_name.endswith("base_layer.weight"): + return param_name[: -len("base_layer.weight")] + "weight" + return param_name +############################## +############################## +############################## + class HfWeightIteratorBridge(HfWeightIteratorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # ############################## + # ###########lora############### + # ############################## + # self.is_lora = is_lora # already get from HfWeightIteratorBase + # self._base_synced = _base_synced # already get from HfWeightIteratorBase + # ############################## + # ############################## + # ############################## + from megatron.bridge import AutoBridge import miles_plugins.megatron_bridge # noqa: F401 @@ -22,25 +46,94 @@ def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. 
modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} with megatron_bridge_utils.patch_megatron_model(self.model): - conversion_tasks = self._bridge.get_conversion_tasks(self.model) - conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - - named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) - - named_weights = ( - ( - hf_param_name, - postprocess_hf_param( - args=self.args, - megatron_param_name=megatron_param_name, - hf_param_name=hf_param_name, - param=weight, - ), + ############################## + ###########lora############### + ############################## + + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + + # named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + + # named_weights = ( + # ( + # hf_param_name, + # postprocess_hf_param( + # args=self.args, + # megatron_param_name=megatron_param_name, + # hf_param_name=hf_param_name, + # param=weight, + # ), + # ) + # for hf_param_name, weight, megatron_param_name in named_weights + # ) + + # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + + #### + + # Only sync base model on first call (or if not LoRA-only mode) + if not self.is_lora or not self._base_synced: + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + named_weights = self._bridge.export_hf_weights( + self.model, + cpu=False, + conversion_tasks=conversion_tasks, + merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). 
) - for hf_param_name, weight, megatron_param_name in named_weights - ) - - yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + named_weights = ( + ( + ############################## + ###########lora############### + ############################## + # hf_param_name, + _normalize_base_weight_name(hf_param_name), + ############################## + ############################## + ############################## + postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in named_weights + ) + yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + if self.is_lora: + self._base_synced = True + # torch.cuda.synchronize() + ############################## + ############################## + ############################## + + ############################## + ###########lora############### + ############################## + if self.is_lora: + lora_weights = self._bridge.export_adapter_weights( + self.model, + cpu=False, + show_progress=False + ) + lora_weights = ( + ( + hf_param_name, + postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in lora_weights + ) + yield from chunk_named_params_by_size(lora_weights, chunk_size=self.args.update_weight_buffer_size) + ############################## + ############################## + ############################## def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict): diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py index af2250dc1..1beded53f 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py @@ -209,4 +209,4 @@ def _get_megatron_local_param_infos(args: Namespace, model: Sequence[torch.nn.Mo infos[i].dtype == param_info.dtype ), f"Parameter dtype mismatch: {infos[i].dtype} != {param_info.dtype}" - return param_infos + return param_infos \ No newline at end of file diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 527d3cfe9..f80204a5b 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -36,6 +36,13 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, + ############################## + ###########lora############### + ############################## + is_lora: bool = False, + ############################## + ############################## + ############################## ) -> None: """ Compute param buckets, create IPC Gloo groups (rollout_num_gpus_per_engine ranks/group). 
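+        # LoRA weight-sync flow (see HfWeightIteratorBridge above): when the iterator is created with
+        # is_lora=True, the first update exports the un-merged base weights once and sets _base_synced,
+        # and later updates export only the adapter tensors via export_adapter_weights(). In the
+        # __init__ below, is_lora=self.is_lora is still commented out, so the iterator currently keeps
+        # exporting merged weights on every update; this appears to be pending the tensor-based LoRA
+        # sync support referenced in the run script (sgl-project/sglang#16226).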
@@ -46,11 +53,26 @@ def __init__( self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 - + ############################## + ###########lora############### + ############################## + self.is_lora = is_lora + self._lora_loaded = False + self._base_synced = False + + # self._hf_weight_iterator = HfWeightIteratorBase.create( + # args=args, model=model, model_name=model_name, quantization_config=quantization_config + # ) self._hf_weight_iterator = HfWeightIteratorBase.create( - args=args, model=model, model_name=model_name, quantization_config=quantization_config + args=args, model=model, model_name=model_name, quantization_config=quantization_config, + # is_lora=self.is_lora, + _base_synced=self._base_synced, ) + ############################## + ############################## + ############################## + # create the group within megatron. for start_rank in range(0, dist.get_world_size(), self.args.rollout_num_gpus_per_engine): end_rank = start_rank + self.args.rollout_num_gpus_per_engine @@ -62,6 +84,7 @@ def __init__( self._model_update_groups = None + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 9bb9b1287..fe7d0b5be 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -16,6 +16,43 @@ from miles.ray.ray_actor import RayActor from miles.utils.http_utils import get_host_info +############################## +###########lora############### +############################## +from argparse import Namespace + +def is_lora_enabled(args: Namespace) -> bool: + """Check if LoRA is enabled.""" + return args.lora_rank > 0 or args.lora_adapter_path is not None + + +def convert_target_modules_to_hf(megatron_modules: list[str]) -> list[str]: + """Convert Megatron LoRA target module names to HuggingFace format. 
+ + Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj + """ + # This mapping should match your specific model architecture + replacements = { + "linear_qkv": ["q_proj", "k_proj", "v_proj"], + "linear_proj": ["o_proj"], + "linear_fc1": ["gate_proj", "up_proj"], + "linear_fc2": ["down_proj"], + } + + hf_modules = [] + for module in megatron_modules: + if module in replacements: + hf_modules.extend(replacements[module]) + else: + # Keep as-is if not in mapping (might already be HF format) + hf_modules.append(module) + + return hf_modules +############################## +############################## +############################## + logger = logging.getLogger(__name__) @@ -53,8 +90,19 @@ def _to_local_gpu_id(physical_gpu_id: int) -> int: def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: from sglang.srt.entrypoints.http_server import launch_server + multiprocessing.set_start_method("spawn", force=True) server_args.host = server_args.host.strip("[]") + ############################## + ###########lora############### + ############################## + # Add logging to see what args are being passed + logger.info(f"Launching SGLang server with args: enable_lora={getattr(server_args, 'enable_lora', None)}, " + f"max_lora_rank={getattr(server_args, 'max_lora_rank', None)}, " + f"base_gpu_id={server_args.base_gpu_id}") + ############################## + ############################## + ############################## p = multiprocessing.Process(target=launch_server, args=(server_args,)) p.start() @@ -333,6 +381,33 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] + ############################## + ###########lora############### + ############################## + def load_lora_adapter(self, lora_name: str, lora_path: str): + """Load LoRA adapter from disk.""" + return self._make_request( + "load_lora_adapter", + {"lora_name": lora_name, "lora_path": lora_path}, + ) + + def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict): + """Load LoRA adapter from serialized tensors.""" + return self._make_request( + "load_lora_adapter_from_tensors", + {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, + ) + + def unload_lora_adapter(self, lora_name: str): + """Unload LoRA adapter.""" + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + ############################## + ############################## + ############################## + def release_memory_occupation(self): self.flush_cache() return self._make_request("release_memory_occupation") @@ -494,6 +569,22 @@ def _compute_server_args( kwargs["dtype"] = "float16" external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] + ############################## + ###########lora############### + ############################## + if is_lora_enabled(args): + kwargs["enable_lora"] = True + kwargs["max_lora_rank"] = args.lora_rank + kwargs["max_loras_per_batch"] = 1 + # NOTE: lora_target_modules might not be supported by your SGLang version + # Comment out this line if SGLang doesn't support it: + # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + # Log for debugging + kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + ############################## + 
############################## + ############################## + unused_keys = set(kwargs.keys()) for attr in dataclasses.fields(ServerArgs): if worker_type == "decode" and attr.name == "enable_hierarchical_cache": diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 77c540d60..d8fb2cc18 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -25,6 +25,8 @@ from .rm_hub import async_rm, batched_async_rm +from miles.backends.sglang_utils.sglang_engine import is_lora_enabled + __all__ = ["generate_rollout"] logger = logging.getLogger(__name__) @@ -136,6 +138,16 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + ############################## + ###########lora############### + ############################## + if is_lora_enabled(args): + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + payload["lora_path"] = LORA_ADAPTER_NAME + ############################## + ############################## + ############################## + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 595542c58..c8dceff6f 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -911,6 +911,74 @@ def add_algo_arguments(parser): help="The threshold for Off-Policy Sequence Masking (OPSM).", ) return parser + + ############################## + ###########lora############### + ############################## + def add_lora_arguments(parser): + """Add LoRA-related arguments for Megatron backend.""" + parser.add_argument( + "--lora-rank", + type=int, + default=0, + help="LoRA rank. Set to 0 to disable LoRA (default: 0)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha for scaling (default: 16)", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout rate (default: 0.0)", + ) + parser.add_argument( + "--target-modules", + type=str, + default=None, + help="Target modules for LoRA. Use 'all-linear' or comma-separated module names " + "(e.g., 'q_proj,k_proj,v_proj,o_proj' for HF naming or 'linear_qkv,linear_proj' for Megatron naming)", + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Modules to exclude from LoRA (comma-separated)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to pre-trained LoRA adapter to load", + ) + parser.add_argument( + "--lora-sync-from-tensor", + action="store_true", + default=False, + help="Sync LoRA weights via tensor instead of file (more efficient)", + ) + # parser.add_argument( + # "--share-ref-base-model", + # action="store_true", + # default=False, + # help="Share base model between actor and reference model (saves memory for LoRA)", + # ) + + parser.add_argument( + "--no-use-distributed-optimizer", + action="store_false", + default=True, + dest="Use distributed optimizer (ZeRO)", + help="Use distributed optimizer (ZeRO). Disable for LoRA training. 
(default: True)", + ) + + return parser + ############################## + ############################## + ############################## def add_router_arguments(parser): parser.add_argument( @@ -1352,6 +1420,13 @@ def add_sglang_tp_size(): parser = add_data_arguments(parser) parser = add_eval_arguments(parser) parser = add_algo_arguments(parser) + ############################## + ###########lora############### + ############################## + parser = add_lora_arguments(parser) + ############################## + ############################## + ############################## parser = add_wandb_arguments(parser) parser = add_tensorboard_arguments(parser) parser = add_router_arguments(parser) @@ -1652,6 +1727,34 @@ def miles_validate_args(args): if args.enable_mtp_training: assert args.mtp_num_layers, "mtp_num_layers must be set when enable_mtp_training is set" + ############################## + ###########lora############### + ############################## + ### considert move these to megatron arguments.py + if args.lora_rank > 0: + assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + + # Parse target modules + if args.target_modules == "all-linear": + # to-do: need to check both on megatron and sglang side support modules and names + # Megatron module names + modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + # Handle excluded modules + if args.exclude_modules: + exclude_set = set(m.strip() for m in args.exclude_modules.split(",")) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + ############################## + ############################## + ############################## + if args.use_rollout_routing_replay: args.use_routing_replay = True From 342d9f5df5bdbd40e4425d0ea42b85783a7b78d1 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 12 Jan 2026 00:45:09 +0000 Subject: [PATCH 03/12] update --- miles/backends/megatron_utils/arguments.py | 32 ---------------------- 1 file changed, 32 deletions(-) diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 03ed0a2e8..8e00dc128 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -10,40 +10,8 @@ def set_default_megatron_args(args): # always use zero optimizer - ############################## - ###########lora############### - ############################## args.use_distributed_optimizer = True - # from miles.backends.megatron_utils.lora_utils import is_lora_enabled - # # this should be enalbe after optimize - # if is_lora_enabled(args): - # # Cannot Use distributed optimizer (ZeRO) in LoRA training. 
- # args.use_distributed_optimizer = False - - # # === NEW: Disable features that cause issues with frozen parameters === - # # Disable gradient accumulation fusion (already have --no-gradient-accumulation-fusion) - # args.gradient_accumulation_fusion = False - - # # Disable async tensor model parallel allreduce to avoid main_grad access - # args.async_tensor_model_parallel_allreduce = False - - # # Disable overlap grad reduce (needs gradient buffers for all params) - # args.overlap_grad_reduce = False - - # # Disable sequence parallel if enabled (can cause similar issues) - # if hasattr(args, 'sequence_parallel') and args.sequence_parallel: - # import logging - # logging.getLogger(__name__).warning( - # "Disabling sequence_parallel for LoRA training (incompatible with frozen parameters)" - # ) - # args.sequence_parallel = False - # else: - # args.use_distributed_optimizer = True - ############################## - ############################## - ############################## - # TODO: maybe change this after megatron has good fp8 support args.bf16 = not args.fp16 # placeholders From 4ce98574ac5eb1822bad683da17a5219bec11168 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 12 Jan 2026 10:59:10 +0000 Subject: [PATCH 04/12] support training side - megatron: base + lora --- miles/backends/megatron_utils/actor.py | 4 +- miles/backends/megatron_utils/lora_utils.py | 126 +++++++------- miles/backends/megatron_utils/model.py | 161 +++++++++++++++++- .../backends/megatron_utils/model_provider.py | 55 +++--- .../hf_weight_iterator_bridge.py | 4 +- .../update_weight_from_tensor.py | 103 ++++++++++- miles/backends/sglang_utils/sglang_engine.py | 24 ++- miles/ray/rollout.py | 32 +++- miles/utils/arguments.py | 32 +++- train.py | 34 ++++ 10 files changed, 458 insertions(+), 117 deletions(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 1961b4e0e..e0634d39e 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -43,7 +43,7 @@ from .lora_utils import ( is_lora_enabled, is_lora_model, - # apply_lora_to_megatron_model, + apply_lora_to_megatron_model, # get_lora_weights_and_config, freeze_base_model, ) @@ -638,7 +638,7 @@ def update_weights(self) -> None: if self.args.offload_train: destroy_process_groups() - + def load_other_checkpoint(self, model_tag: str, path: str) -> None: old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune self.args.load = path diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py index 5bf370314..44f9355b4 100644 --- a/miles/backends/megatron_utils/lora_utils.py +++ b/miles/backends/megatron_utils/lora_utils.py @@ -27,78 +27,78 @@ def is_lora_enabled(args: Namespace) -> bool: return args.lora_rank > 0 or args.lora_adapter_path is not None -# def apply_lora_to_megatron_model( -# model: Sequence[torch.nn.Module], -# args: Namespace, -# ) -> Sequence[torch.nn.Module]: -# """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. +def apply_lora_to_megatron_model( + model: Sequence[torch.nn.Module], + args: Namespace, +) -> Sequence[torch.nn.Module]: + """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. 
-# This uses the Megatron-Bridge's PEFT support from: -# https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft + This uses the Megatron-Bridge's PEFT support from: + https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft -# Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 + Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 -# Args: -# model: Megatron model (DDP wrapped) -# args: Training arguments with LoRA config + Args: + model: Megatron model (DDP wrapped) + args: Training arguments with LoRA config -# Returns: -# LoRA-wrapped model -# """ -# # from megatron.bridge.peft import apply_lora_adapter, LoraConfig -# from megatron.bridge.peft.lora import LoRA - -# if args.lora_adapter_path: -# # TODO: Loading existing LoRA adapter needs separate implementation -# # Megatron-Bridge may have different API for loading -# # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 -# raise NotImplementedError("Loading existing LoRA adapter is not yet implemented") -# else: -# # Determine lora_dtype from args -# if hasattr(args, 'bf16') and args.bf16: -# lora_dtype = torch.bfloat16 -# elif hasattr(args, 'fp16') and args.fp16: -# lora_dtype = torch.float16 -# else: -# lora_dtype = None # Will use model's dtype + Returns: + LoRA-wrapped model + """ + # from megatron.bridge.peft import apply_lora_adapter, LoraConfig + from megatron.bridge.peft.lora import LoRA + + if args.lora_adapter_path: + # TODO: Loading existing LoRA adapter needs separate implementation + # Megatron-Bridge may have different API for loading + # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 + raise NotImplementedError("Loading existing LoRA adapter is not yet implemented") + else: + # Determine lora_dtype from args + if hasattr(args, 'bf16') and args.bf16: + lora_dtype = torch.bfloat16 + elif hasattr(args, 'fp16') and args.fp16: + lora_dtype = torch.float16 + else: + lora_dtype = None # Will use model's dtype -# # Get exclude_modules as list -# exclude_modules = [] -# if hasattr(args, 'exclude_modules') and args.exclude_modules: -# if isinstance(args.exclude_modules, str): -# exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] -# else: -# exclude_modules = list(args.exclude_modules) + # Get exclude_modules as list + exclude_modules = [] + if hasattr(args, 'exclude_modules') and args.exclude_modules: + if isinstance(args.exclude_modules, str): + exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] + else: + exclude_modules = list(args.exclude_modules) -# # Create new LoRA adapter using Megatron-Bridge LoRA dataclass -# # There are different lora_type, I just use the classic one (speed and acc might not the optimal) -# # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 -# lora = LoRA( -# target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] -# exclude_modules=exclude_modules, # Modules to exclude from LoRA -# dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) -# alpha=args.lora_alpha, # LoRA alpha scaling factor -# 
dropout=args.lora_dropout, # LoRA dropout rate -# dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' -# lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix -# lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix -# a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication -# lora_dtype=lora_dtype, # Parameter data type for LoRA weights -# ) -# logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " -# f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " -# f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") + # Create new LoRA adapter using Megatron-Bridge LoRA dataclass + # There are different lora_type, I just use the classic one (speed and acc might not the optimal) + # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 + lora = LoRA( + target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + exclude_modules=exclude_modules, # Modules to exclude from LoRA + dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) + alpha=args.lora_alpha, # LoRA alpha scaling factor + dropout=args.lora_dropout, # LoRA dropout rate + dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' + lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix + lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix + a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication + lora_dtype=lora_dtype, # Parameter data type for LoRA weights + ) + logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " + f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " + f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") -# # Apply LoRA to each model chunk -# # The LoRA class is callable - calling it applies the transformation -# for model_chunk in model: -# # lora(model_chunk.module, training=True) applies LoRA and freezes base model -# lora(model_chunk.module, training=True) + # Apply LoRA to each model chunk + # The LoRA class is callable - calling it applies the transformation + for model_chunk in model: + # lora(model_chunk.module, training=True) applies LoRA and freezes base model + lora(model_chunk.module, training=True) -# # Print trainable parameters info -# _print_trainable_parameters(model) + # Print trainable parameters info + _print_trainable_parameters(model) -# return model + return model def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index d3ffd6a53..841799d8a 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -31,6 +31,8 @@ # from .checkpoint import load_checkpoint, save_checkpoint from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora from .lora_utils import is_lora_model +# from miles.backends.megatron_utils.lora_utils import is_lora_enabled +from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model ############################## ############################## ############################## @@ -114,7 +116,149 @@ def 
setup_model_and_optimizer( assert not args.moe_use_upcycling assert args.load is not None or args.pretrained_checkpoint is not None - model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + ############################## + ###########lora############### + ############################## + # model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + + # if is_lora_enabled(args): + + # from megatron.core.distributed import DistributedDataParallelConfig + # from megatron.bridge.models.model_provider import get_model + # provider = get_model_provider_func(args, role) + + # ddp_config = DistributedDataParallelConfig( + # grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), + # check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), + # overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), + # overlap_param_gather=getattr(args, 'overlap_param_gather', False), + # average_in_collective=getattr(args, 'average_in_collective', False), + # use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), + # ) + # # model = provider.provide_distributed_model( + # # ddp_config=ddp_config, + # # wrap_with_ddp=True, + # # bf16=getattr(args, 'bf16', False), + # # fp16=getattr(args, 'fp16', False), + # # ) + + # model = get_model( + # model_provider=provider, # must be ModelProviderMixin object + # ddp_config=ddp_config, + # model_type=ModelType.encoder_or_decoder, + # wrap_with_ddp=True, + # use_cpu_initialization=False, + # ) + + + # print(111111) + # print(model) + # print(111111) + # exit() + # else: + # model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + + + ########### + + if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": + from megatron.core.distributed import DistributedDataParallelConfig + from megatron.bridge.models.model_provider import get_model as bridge_get_model + from megatron.bridge import AutoBridge + from megatron.bridge.peft.lora import LoRA + import torch + + # Build the provider from HF checkpoint + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + provider = bridge.to_megatron_provider(load_weights=False) + + # Set parallel configs on the provider + provider.tensor_model_parallel_size = args.tensor_model_parallel_size + provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size + provider.expert_model_parallel_size = args.expert_model_parallel_size + provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size + provider.sequence_parallel = args.sequence_parallel + + # Determine lora_dtype + if hasattr(args, 'bf16') and args.bf16: + lora_dtype = torch.bfloat16 + elif hasattr(args, 'fp16') and args.fp16: + lora_dtype = torch.float16 + else: + lora_dtype = None + + # Get exclude_modules as list + exclude_modules = [] + if hasattr(args, 'exclude_modules') and args.exclude_modules: + if isinstance(args.exclude_modules, str): + exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] + else: + exclude_modules = list(args.exclude_modules) + + # Create LoRA config + lora = LoRA( + target_modules=args.target_modules, + exclude_modules=exclude_modules, + dim=args.lora_rank, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + dropout_position=getattr(args, 'lora_dropout_position', 'pre'), + lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), + lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), + 
a2a_experimental=getattr(args, 'lora_a2a_experimental', False), + lora_dtype=lora_dtype, + ) + + # Define pre_wrap_hook to apply LoRA before DDP wrapping + def apply_lora_hook(model_chunks): + transformed = lora(model_chunks, training=True) + lora.set_params_to_save(transformed) + return transformed + + # Register the hook + provider.register_pre_wrap_hook(apply_lora_hook) + provider.finalize() + + # Build DDP config + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), + check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), + overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), + overlap_param_gather=getattr(args, 'overlap_param_gather', False), + average_in_collective=getattr(args, 'average_in_collective', False), + use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), + ) + + # Use Bridge's get_model with the provider (which now has LoRA hook registered) + model = bridge_get_model( + model_provider=provider, + ddp_config=ddp_config, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + use_cpu_initialization=False, + bf16=getattr(args, 'bf16', False), + fp16=getattr(args, 'fp16', False), + pre_wrap_hook=provider.pre_wrap_hook, + ) + + # Store lora instance for later use (e.g., checkpoint saving) + # You may want to attach this to the model or args for later access + if hasattr(args, '_lora_instance'): + args._lora_instance = lora + + # print(11111111) + # print(model) + # print(11111111) + # exit() + else: + # Original non-LoRA path or non-bridge mode + model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + + + ############################## + ############################## + ############################## + ############################## ###########lora############### @@ -122,6 +266,20 @@ def setup_model_and_optimizer( # from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model # if is_lora_enabled(args) and role == "actor": # model = apply_lora_to_megatron_model(model, args) + + ######### + # if is_lora_enabled(args) and role == "actor": + # from megatron.bridge.peft.lora import LoRA + + # lora = LoRA( + # target_modules=args.target_modules, + # dim=args.lora_rank, + # alpha=args.lora_alpha, + # dropout=args.lora_dropout, + # ) + # # model is list[DDP],it require unwrap + # model = lora(model, training=True) + # lora.set_params_to_save(model) ############################## ############################## ############################## @@ -133,7 +291,6 @@ def setup_model_and_optimizer( kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) config.timers = None - optimizer = get_megatron_optimizer( config=config, model_chunks=model, diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 7b6a285bf..c854bfb6f 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -19,6 +19,9 @@ from miles.utils.misc import load_function + + + # Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82 class LinearForLastLayer(torch.nn.Linear): def __init__( @@ -82,7 +85,7 @@ def wrapped_model_provider( ############################## ###########lora############### ############################## - from miles.backends.megatron_utils.lora_utils import is_lora_enabled + # from 
miles.backends.megatron_utils.lora_utils import is_lora_enabled ############################## ############################## ############################## @@ -97,43 +100,29 @@ def wrapped_model_provider( provider.sequence_parallel = args.sequence_parallel ############################## ###########lora############### - ############################## - # Register LoRA pre_wrap_hook(before setting up DDP) - if is_lora_enabled(args) and role == "actor": - def lora_pre_wrap_hook(model): - """Apply LoRA to model BEFORE DDP wrapping.""" - from megatron.bridge.peft.lora import LoRA - import torch - - # Set up lora_dtype - if hasattr(args, 'bf16') and args.bf16: - lora_dtype = torch.bfloat16 - elif hasattr(args, 'fp16') and args.fp16: - lora_dtype = torch.float16 - else: - lora_dtype = None - - lora = LoRA( - target_modules=args.target_modules, - dim=args.lora_rank, - alpha=args.lora_alpha, - dropout=args.lora_dropout, - lora_dtype=lora_dtype, - ) - - # Apply LoRA and freeze base model - transformed_model = lora(model, training=True) - lora.set_params_to_save(transformed_model) - - return transformed_model - - provider.register_pre_wrap_hook(lora_pre_wrap_hook) + ############################## + # if is_lora_enabled(args) and role == "actor": + # from megatron.bridge.peft.lora import LoRA + # lora = LoRA( + # target_modules=args.target_modules, + # dim=args.lora_rank, + # alpha=args.lora_alpha, + # dropout=args.lora_dropout, + # # lora_dtype=lora_dtype, + # ) + # # Apply LoRA and freeze base model + # def apply_lora(model_chunks): + # transformed = lora(model_chunks, training=True) + # lora.set_params_to_save(transformed) + # return transformed + # provider.register_pre_wrap_hook(apply_lora) ############################## ############################## ############################## provider.finalize() return provider.provide - + + def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: """Builds the model. 
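Read end to end, the bridge-based LoRA path above reduces to the sketch below. It is illustrative only: the checkpoint path and LoRA hyperparameters are placeholders taken from the example script, and the calls mirror the ones used in setup_model_and_optimizer in this patch rather than a documented public API.

    from megatron.bridge import AutoBridge
    from megatron.bridge.peft.lora import LoRA

    # Build a Megatron model provider from the HF checkpoint; weights are loaded later.
    bridge = AutoBridge.from_hf_pretrained("/root/Qwen2.5-0.5B-Instruct/", trust_remote_code=True)
    provider = bridge.to_megatron_provider(load_weights=False)

    # LoRA config mirroring LORA_ARGS in the example script (rank 16, alpha 32, no dropout).
    lora = LoRA(
        target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
        dim=16,
        alpha=32,
        dropout=0.0,
    )

    def apply_lora_hook(model_chunks):
        # Wrap the target linear layers and freeze the base weights before DDP wrapping.
        transformed = lora(model_chunks, training=True)
        # Mark the adapter parameters as the ones to save (used later for LoRA checkpointing).
        lora.set_params_to_save(transformed)
        return transformed

    provider.register_pre_wrap_hook(apply_lora_hook)
    provider.finalize()
    # The provider is then handed to megatron.bridge.models.model_provider.get_model()
    # together with a DistributedDataParallelConfig, as in the hunk above.
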
diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index f15a0491c..07606c5ce 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -87,8 +87,8 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## - # hf_param_name, - _normalize_base_weight_name(hf_param_name), + hf_param_name, + # _normalize_base_weight_name(hf_param_name), ############################## ############################## ############################## diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index f80204a5b..85ab0ef28 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -19,6 +19,14 @@ update_weights_from_distributed, ) +############################## +###########lora############### +############################## +from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled +############################## +############################## +############################## + class UpdateWeightFromTensor: """ @@ -65,7 +73,7 @@ def __init__( # ) self._hf_weight_iterator = HfWeightIteratorBase.create( args=args, model=model, model_name=model_name, quantization_config=quantization_config, - # is_lora=self.is_lora, + is_lora=self.is_lora, _base_synced=self._base_synced, ) ############################## @@ -139,13 +147,100 @@ def update_weights(self) -> None: megatron_local_weights = self.weights_getter() + ############################## + ###########lora############### + ############################## + lora_named_tensors = [] + ############################## + ############################## + ############################## + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) - ray.get(refs) - del long_lived_tensors + ############################## + ###########lora############### + ############################## + # refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) + # ray.get(refs) + # del long_lived_tensors + + + # Check if this chunk contains LoRA weights + if self.is_lora: + # print() + lora_weights = [(name, tensor) for name, tensor in hf_named_tensors + if 'lora_' in name.lower() or 'adapter' in name.lower()] + # print(1111111) + # print(hf_named_tensors) + # print(lora_weights) + # print(1111111) + # exit() + base_weights = [(name, tensor) for name, tensor in hf_named_tensors + if 'lora_' not in name.lower()] + + # Sync base weights normally + if base_weights: + refs, long_lived_tensors = self._send_hf_params(base_weights) + ray.get(refs) + del long_lived_tensors + + # Collect LoRA weights for later + lora_named_tensors.extend(lora_weights) + else: + refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) + ray.get(refs) + del long_lived_tensors + + + # After syncing all weights, load LoRA adapter into SGLang + if self.is_lora and lora_named_tensors: + self._load_lora_adapter(lora_named_tensors) + ############################## + ############################## + ############################## 
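+        # The adapter tensors collected above are flattened, serialized, and (from rank 0)
+        # pushed to every rollout engine via load_lora_adapter_from_tensors; see the
+        # _load_lora_adapter helper added below. The gloo barrier that follows keeps all
+        # training ranks together until the weight transfer has completed.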
dist.barrier(group=get_gloo_group()) + ############################## + ###########lora############### + ############################## + def _load_lora_adapter(self, lora_named_tensors: list[tuple[str, torch.Tensor]]) -> None: + """Load LoRA adapter into SGLang engine.""" + from ..sglang import FlattenedTensorBucket, MultiprocessingSerializer + + # Create config dict + config_dict = { + "peft_type": "LORA", + "r": self.args.lora_rank, + "lora_alpha": self.args.lora_alpha, + "target_modules": list(self.args.target_modules) if self.args.target_modules else [], + "bias": "none", + } + + # Serialize LoRA tensors + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": flattened_tensor_bucket.get_metadata(), + } + serialized_tensors = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + + # Load adapter on rank 0 + rank = dist.get_rank() + if rank == 0: + refs = [ + engine.load_lora_adapter_from_tensors.remote( + lora_name=LORA_ADAPTER_NAME, + serialized_tensors=serialized_tensors, + config_dict=config_dict, + ) + for engine in self.rollout_engines + ] + ray.get(refs) + self._lora_loaded = True + ############################## + ############################## + ############################## + def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: all_refs = [] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index fe7d0b5be..1b362dc45 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -408,9 +408,29 @@ def unload_lora_adapter(self, lora_name: str): ############################## ############################## - def release_memory_occupation(self): + + ############################## + ###########lora############### + ############################## + # def release_memory_occupation(self): + # self.flush_cache() + # return self._make_request("release_memory_occupation") + + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage release: weights, kv_cache + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) + ############################## + ############################## + ############################## + + + def resume_memory_occupation(self, tags: list[str] = None): """ diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 4b22c5ddc..725b82617 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -130,8 +130,36 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + ############################## + ###########lora############### + ############################## + # def offload(self): + # return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + + def offload(self, tags: list[str] | None = None): + self.health_monitoring_pause() + return ray.get( + [ + engine.release_memory_occupation.remote(tags=tags) + for engine in self.rollout_engines + if engine is not None + ] + ) + + + def health_monitoring_pause(self): + if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): + self._health_monitor.stop() + + def 
health_monitoring_resume(self): + if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): + self._health_monitor.start() + + ############################## + ############################## + ############################## + + def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index c8dceff6f..63bf34d57 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,6 +104,24 @@ def add_cluster_arguments(parser): ), ) + ############################## + ###########lora############### + ############################## + parser.add_argument( + "--offload-rollout-level", + type=str, + nargs="+", + default=["kv_cache", "weight"], + help=( + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" + ), + ) + ############################## + ############################## + ############################## + reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -967,13 +985,13 @@ def add_lora_arguments(parser): # help="Share base model between actor and reference model (saves memory for LoRA)", # ) - parser.add_argument( - "--no-use-distributed-optimizer", - action="store_false", - default=True, - dest="Use distributed optimizer (ZeRO)", - help="Use distributed optimizer (ZeRO). Disable for LoRA training. (default: True)", - ) + # parser.add_argument( + # "--no-use-distributed-optimizer", + # action="store_false", + # default=True, + # dest="Use distributed optimizer (ZeRO)", + # help="Use distributed optimizer (ZeRO). Disable for LoRA training. (default: True)", + # ) return parser ############################## diff --git a/train.py b/train.py index a4f6824cc..2ed15716c 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,16 @@ from miles.utils.misc import should_run_periodic_action from miles.utils.tracking_utils import init_tracking +############################## +###########lora############### +############################## +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +############################## +############################## +############################## + + + def train(args): configure_logger() @@ -55,10 +65,20 @@ def offload_train(): else: actor_model.clear_memory() + ############################## + ###########lora############### + ############################## def onload_rollout(): if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # def onload_rollout(): + # if args.offload_rollout and "weight" in args.offload_rollout_level: + # ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + ############################## + ############################## + ############################## + # train loop. # note that for async training, one can change the position of the sync operation(ray.get). 
for rollout_id in range(args.start_rollout_id, args.num_rollout): @@ -67,9 +87,23 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) + ############################## + ###########lora############### + ############################## if args.offload_rollout: ray.get(rollout_manager.offload.remote()) + # if args.offload_rollout: + # offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + # if "kv_cache" in args.offload_rollout_level: + # offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + # if "weight" in args.offload_rollout_level: + # offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + # ray.get(rollout_manager.offload.remote(tags=offload_tags)) + ############################## + ############################## + ############################## + if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) if rollout_id >= args.num_critic_only_steps: From 4f5cf71b425a05acfaa56dedfd1fb5cc6a1e805e Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Tue, 13 Jan 2026 19:22:41 +0000 Subject: [PATCH 05/12] support rollout part --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 33 +++- miles/backends/megatron_utils/model.py | 4 - .../hf_weight_iterator_bridge.py | 152 ++++++++++-------- .../update_weight_from_tensor.py | 129 ++++++++------- miles/backends/sglang_utils/sglang_engine.py | 58 +++++-- miles/ray/rollout.py | 20 ++- miles/rollout/sglang_rollout.py | 6 + miles/utils/arguments.py | 55 ++++--- 8 files changed, 272 insertions(+), 185 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 1cc67bbcc..b5f257d86 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -22,8 +22,10 @@ CKPT_ARGS=( --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ # --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ # Uncomment to save checkpoints (required for LoRA) + #### train --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ --save-interval 5 + ### ) @@ -36,18 +38,31 @@ LORA_ARGS=( --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) # Target modules - use Megatron naming or HF naming # Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 - --target-modules "all-linear" # Need this PR: Update LoRA Weights via Tensor sgl-project/sglang#16226 # --lora-sync-from-tensor # Use tensor-based sync (more efficient) - ## Uncomment to share base model between actor and ref (saves memory) + # # Uncomment to share base model between actor and ref (saves memory) # --share-ref-base-model + + --target-modules "all-linear" + # --target-modules "o_proj,down_proj,k_proj,gate_proj,q_proj,v_proj,up_proj" + # --target-modules "q_proj,k_proj,v_proj,o_proj" ############################## ############################## # # Debug - # --debug-rollout-only - --debug-train-only - --load-debug-rollout-data /root/debug_data/rollout_data.pt - # # --save-debug-rollout-data /root/debug_data/rollout_data.pt + #### inference + --debug-rollout-only + ### --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/lora_adapter.pt + --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA + # --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + ### + + #### train + # --debug-train-only + # --load-debug-rollout-data /root/debug_data/rollout_data.pt + ## --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + + # --save-debug-rollout-data /root/debug_data/rollout_data.pt + ### ############################## 
############################## # --no-use-distributed-optimizer # if open it will has error: /home/radixark/yushengsu/miles-pr/miles/miles/utils/arguments.py: @@ -153,7 +168,7 @@ MISC_ARGS=( ############################## ###########lora############### ############################## -export GPUS_PER_NODE=1 +export GPUS_PER_NODE=2 ############################## ############################## ############################## @@ -190,3 +205,7 @@ ray job submit --address="http://127.0.0.1:8265" \ ${MISC_ARGS[@]} \ ${ROLLOUT_ARGS[@]} \ ${LORA_ARGS[@]} + + +# colocate : update from tesnor +# disaggrate : update from distributed \ No newline at end of file diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 841799d8a..14edfe011 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -246,10 +246,6 @@ def apply_lora_hook(model_chunks): if hasattr(args, '_lora_instance'): args._lora_instance = lora - # print(11111111) - # print(model) - # print(11111111) - # exit() else: # Original non-LoRA path or non-bridge mode model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 07606c5ce..4927b2000 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -49,88 +49,104 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## + ## This is the origin way - weight sync will process - base model + lora weights + ## to-do (yusheng): Optimize: use the method in `self.is_lora` but need to deal with CUDA issue (weight not on the same device) - might need to be delt with in megatron-core - # conversion_tasks = self._bridge.get_conversion_tasks(self.model) - # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - # named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) - # named_weights = ( - # ( - # hf_param_name, - # postprocess_hf_param( - # args=self.args, - # megatron_param_name=megatron_param_name, - # hf_param_name=hf_param_name, - # param=weight, - # ), - # ) - # for hf_param_name, weight, megatron_param_name in named_weights - # ) + # for hf_param_name, weight, megatron_param_name in named_weights: + # print(hf_param_name) + # exit() + + named_weights = ( + ( + hf_param_name, + postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in named_weights + ) - # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) #### - # Only sync base model on first call (or if not LoRA-only mode) - if not self.is_lora or not self._base_synced: - conversion_tasks = self._bridge.get_conversion_tasks(self.model) 
- conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - named_weights = self._bridge.export_hf_weights( - self.model, - cpu=False, - conversion_tasks=conversion_tasks, - merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). - ) - named_weights = ( - ( - ############################## - ###########lora############### - ############################## - hf_param_name, - # _normalize_base_weight_name(hf_param_name), - ############################## - ############################## - ############################## - postprocess_hf_param( - args=self.args, - megatron_param_name=megatron_param_name, - hf_param_name=hf_param_name, - param=weight, - ), - ) - for hf_param_name, weight, megatron_param_name in named_weights - ) - yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) - if self.is_lora: - self._base_synced = True - # torch.cuda.synchronize() + # # Only sync base model on first call (or if not LoRA-only mode) + # # if not self.is_lora or not self._base_synced: + # if not self.is_lora: + # # only pass base model + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + # named_weights = self._bridge.export_hf_weights( + # self.model, + # cpu=False, + # conversion_tasks=conversion_tasks, + # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). + # ) + + # # for hf_param_name, weight, megatron_param_name in named_weights: + # # print(hf_param_name) + + # named_weights = ( + # ( + # ############################## + # ###########lora############### + # ############################## + # hf_param_name, + # # _normalize_base_weight_name(hf_param_name), + # ############################## + # ############################## + # ############################## + # postprocess_hf_param( + # args=self.args, + # megatron_param_name=megatron_param_name, + # hf_param_name=hf_param_name, + # param=weight, + # ), + # ) + # for hf_param_name, weight, megatron_param_name in named_weights + # ) + # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + # if self.is_lora: + # self._base_synced = True + # # torch.cuda.synchronize() ############################## ############################## ############################## + + ############################## ###########lora############### ############################## - if self.is_lora: - lora_weights = self._bridge.export_adapter_weights( - self.model, - cpu=False, - show_progress=False - ) - lora_weights = ( - ( - hf_param_name, - postprocess_hf_param( - args=self.args, - megatron_param_name=megatron_param_name, - hf_param_name=hf_param_name, - param=weight, - ), - ) - for hf_param_name, weight, megatron_param_name in lora_weights - ) - yield from chunk_named_params_by_size(lora_weights, chunk_size=self.args.update_weight_buffer_size) + # if self.is_lora: + # lora_weights = self._bridge.export_adapter_weights( + # self.model, + # cpu=False, + # show_progress=False + # ) + + # lora_weights = ( + # ( + # hf_param_name, + # postprocess_hf_param( + # args=self.args, + # megatron_param_name=megatron_param_name, + # hf_param_name=hf_param_name, + # param=weight, + # ), + # ) + # for hf_param_name, weight, megatron_param_name in lora_weights + # ) + + # yield from chunk_named_params_by_size(lora_weights, 
chunk_size=self.args.update_weight_buffer_size) ############################## ############################## ############################## diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 85ab0ef28..80a42eba4 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -71,6 +71,8 @@ def __init__( # self._hf_weight_iterator = HfWeightIteratorBase.create( # args=args, model=model, model_name=model_name, quantization_config=quantization_config # ) + + self._hf_weight_iterator = HfWeightIteratorBase.create( args=args, model=model, model_name=model_name, quantization_config=quantization_config, is_lora=self.is_lora, @@ -150,93 +152,90 @@ def update_weights(self) -> None: ############################## ###########lora############### ############################## - lora_named_tensors = [] + # lora_named_tensors = [] ############################## ############################## ############################## - for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): ############################## ###########lora############### ############################## - # refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) - # ray.get(refs) - # del long_lived_tensors - - - # Check if this chunk contains LoRA weights - if self.is_lora: - # print() - lora_weights = [(name, tensor) for name, tensor in hf_named_tensors - if 'lora_' in name.lower() or 'adapter' in name.lower()] - # print(1111111) - # print(hf_named_tensors) - # print(lora_weights) - # print(1111111) - # exit() - base_weights = [(name, tensor) for name, tensor in hf_named_tensors - if 'lora_' not in name.lower()] + refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) + ray.get(refs) + del long_lived_tensors + + # above original way: will pass base+lora weight + + ########## (Optimize - lora pass only): refer - https://github.com/radixark/miles/blob/3b975df8e3e6af3453d36de491a764334f16b059/miles/backends/fsdp_utils/update_weight_utils.py + # # Only pass throguht lora weight + # # Check if this chunk contains LoRA weights + # if self.is_lora: + # # print() + # lora_weights = [(name, tensor) for name, tensor in hf_named_tensors + # if 'lora_' in name.lower() or 'adapter' in name.lower()] + + # base_weights = [(name, tensor) for name, tensor in hf_named_tensors + # if 'lora_' not in name.lower()] - # Sync base weights normally - if base_weights: - refs, long_lived_tensors = self._send_hf_params(base_weights) - ray.get(refs) - del long_lived_tensors + # # Sync base weights normally + # if base_weights: + # refs, long_lived_tensors = self._send_hf_params(base_weights) + # ray.get(refs) + # del long_lived_tensors - # Collect LoRA weights for later - lora_named_tensors.extend(lora_weights) - else: - refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) - ray.get(refs) - del long_lived_tensors + # # Collect LoRA weights for later + # lora_named_tensors.extend(lora_weights) + # else: + # refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) + # ray.get(refs) + # del long_lived_tensors - # After syncing all weights, load LoRA adapter into SGLang - if self.is_lora and lora_named_tensors: - self._load_lora_adapter(lora_named_tensors) + # # After syncing all weights, load LoRA adapter into SGLang + # if self.is_lora and lora_named_tensors: + # 
self._load_lora_adapter(lora_named_tensors) ############################## ############################## ############################## - dist.barrier(group=get_gloo_group()) ############################## ###########lora############### ############################## - def _load_lora_adapter(self, lora_named_tensors: list[tuple[str, torch.Tensor]]) -> None: - """Load LoRA adapter into SGLang engine.""" - from ..sglang import FlattenedTensorBucket, MultiprocessingSerializer + # def _load_lora_adapter(self, lora_named_tensors: list[tuple[str, torch.Tensor]]) -> None: + # """Load LoRA adapter into SGLang engine.""" + # from ..sglang import FlattenedTensorBucket, MultiprocessingSerializer - # Create config dict - config_dict = { - "peft_type": "LORA", - "r": self.args.lora_rank, - "lora_alpha": self.args.lora_alpha, - "target_modules": list(self.args.target_modules) if self.args.target_modules else [], - "bias": "none", - } + # # Create config dict + # config_dict = { + # "peft_type": "LORA", + # "r": self.args.lora_rank, + # "lora_alpha": self.args.lora_alpha, + # "target_modules": list(self.args.target_modules) if self.args.target_modules else [], + # "bias": "none", + # } - # Serialize LoRA tensors - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": flattened_tensor_bucket.get_metadata(), - } - serialized_tensors = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + # # Serialize LoRA tensors + # flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) + # flattened_tensor_data = { + # "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + # "metadata": flattened_tensor_bucket.get_metadata(), + # } + # serialized_tensors = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - # Load adapter on rank 0 - rank = dist.get_rank() - if rank == 0: - refs = [ - engine.load_lora_adapter_from_tensors.remote( - lora_name=LORA_ADAPTER_NAME, - serialized_tensors=serialized_tensors, - config_dict=config_dict, - ) - for engine in self.rollout_engines - ] - ray.get(refs) - self._lora_loaded = True + # # Load adapter on rank 0 + # rank = dist.get_rank() + # if rank == 0: + # refs = [ + # engine.load_lora_adapter_from_tensors.remote( + # lora_name=LORA_ADAPTER_NAME, + # serialized_tensors=serialized_tensors, + # config_dict=config_dict, + # ) + # for engine in self.rollout_engines + # ] + # ray.get(refs) + # self._lora_loaded = True ############################## ############################## ############################## diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 1b362dc45..abfb1fb9f 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -32,6 +32,13 @@ def convert_target_modules_to_hf(megatron_modules: list[str]) -> list[str]: Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj """ + + # If "all-linear" was converted to the standard Megatron list, just return "all" + # This allows SGLang to accept any LoRA adapter regardless of its target modules + # standard_all_linear = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + # if set(megatron_modules) == set(standard_all_linear): + # return "all" + # This mapping should match your specific model architecture replacements = { 
"linear_qkv": ["q_proj", "k_proj", "v_proj"], @@ -96,10 +103,14 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: ############################## ###########lora############### ############################## + # for debugging - can be removed # Add logging to see what args are being passed - logger.info(f"Launching SGLang server with args: enable_lora={getattr(server_args, 'enable_lora', None)}, " - f"max_lora_rank={getattr(server_args, 'max_lora_rank', None)}, " - f"base_gpu_id={server_args.base_gpu_id}") + logger.info(f"Launching SGLang server with args:") + logger.info(f" enable_lora={getattr(server_args, 'enable_lora', None)}") + logger.info(f" max_lora_rank={getattr(server_args, 'max_lora_rank', None)}") + logger.info(f" max_loras_per_batch={getattr(server_args, 'max_loras_per_batch', None)}") + logger.info(f" lora_target_modules={getattr(server_args, 'lora_target_modules', None)}") + logger.info(f" base_gpu_id={server_args.base_gpu_id}") ############################## ############################## ############################## @@ -592,15 +603,41 @@ def _compute_server_args( ############################## ###########lora############### ############################## - if is_lora_enabled(args): + # if is_lora_enabled(args): + # kwargs["enable_lora"] = True + # kwargs["max_lora_rank"] = args.lora_rank + # kwargs["max_loras_per_batch"] = 1 + # # NOTE: lora_target_modules might not be supported by your SGLang version + # # Comment out this line if SGLang doesn't support it: + # # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + # # Log for debugging + # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + # # kwargs["lora_target_modules_list"] = convert_target_modules_to_hf(args.target_modules) + # # print(1111111111) + # # print(kwargs["lora_target_modules"]) + # # print(2222222222) + # # exit() + + if args.lora_rank > 0 or args.lora_adapter_path is not None: + kwargs["max_loras_per_batch"] = 1 #!!!!!!!! kwargs["enable_lora"] = True - kwargs["max_lora_rank"] = args.lora_rank - kwargs["max_loras_per_batch"] = 1 - # NOTE: lora_target_modules might not be supported by your SGLang version - # Comment out this line if SGLang doesn't support it: - # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) - # Log for debugging + # kwargs["max_lora_rank"] = args.lora_rank + # Ensure a valid positive LoRA rank is passed to the SGLang engine. + # If LoRA is enabled via adapter path but lora_rank is not set to a + # positive value, default to rank 1 to avoid an invalid configuration. 
+ if getattr(args, "lora_rank", None) and args.lora_rank > 0: + max_lora_rank = args.lora_rank + else: + max_lora_rank = 1 + kwargs["max_lora_rank"] = max_lora_rank + # kwargs["lora_target_modules"] = args.target_modules kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + + ##### For rollout debug mode to add: + if args.debug_rollout_only and args.lora_adapter_path: + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} + kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} ############################## ############################## ############################## @@ -613,6 +650,7 @@ def _compute_server_args( kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") unused_keys.discard(attr.name) + # for compatibility with old args if len(unused_keys) > 0: logger.info(f"Warning: The following arguments is not supported in the current sglang: {unused_keys}.") diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 725b82617..08c886516 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -145,7 +145,6 @@ def offload(self, tags: list[str] | None = None): if engine is not None ] ) - def health_monitoring_pause(self): if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): @@ -154,13 +153,11 @@ def health_monitoring_pause(self): def health_monitoring_resume(self): if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): self._health_monitor.start() - ############################## ############################## ############################## - def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) @@ -487,6 +484,23 @@ def init_rollout_engines(args, pg, all_rollout_engines): init_handles = [engine.init.remote(**(addr_and_ports[rank])) for rank, engine in rollout_engines] ray.get(init_handles) + # ############################## + # ###########lora############### + # ############################## + # # Load LoRA adapter from disk in debug-rollout-only mode + # if args.debug_rollout_only and args.lora_adapter_path: + # from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # logger.info(f"Loading LoRA adapter from {args.lora_adapter_path} for debug-rollout-only mode") + # for i, engine in rollout_engines: + # ray.get(engine.load_lora_adapter.remote( + # lora_name=LORA_ADAPTER_NAME, + # lora_path=args.lora_adapter_path + # )) + # logger.info("LoRA adapter loaded successfully") + # ############################## + # ############################## + # ############################## + return num_new_engines diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index d8fb2cc18..74c4bfd1d 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -25,7 +25,13 @@ from .rm_hub import async_rm, batched_async_rm +############################## +###########lora############### +############################## from miles.backends.sglang_utils.sglang_engine import is_lora_enabled +############################## +############################## +############################## __all__ = ["generate_rollout"] diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 63bf34d57..78c6ac01f 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1609,6 +1609,33 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, 
"'--save' is required when save_interval is set." + ############################## + ###########lora############### + ############################## + if args.lora_rank > 0: + assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + ############################## + ############################## + ############################## + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: @@ -1745,34 +1772,6 @@ def miles_validate_args(args): if args.enable_mtp_training: assert args.mtp_num_layers, "mtp_num_layers must be set when enable_mtp_training is set" - ############################## - ###########lora############### - ############################## - ### considert move these to megatron arguments.py - if args.lora_rank > 0: - assert args.save is not None, "'--save' is required when LoRA is enabled." - assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." - - # Parse target modules - if args.target_modules == "all-linear": - # to-do: need to check both on megatron and sglang side support modules and names - # Megatron module names - modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] - elif "," in args.target_modules: - modules = [m.strip() for m in args.target_modules.split(",")] - else: - modules = [args.target_modules] - - # Handle excluded modules - if args.exclude_modules: - exclude_set = set(m.strip() for m in args.exclude_modules.split(",")) - modules = [m for m in modules if m not in exclude_set] - - args.target_modules = modules - ############################## - ############################## - ############################## - if args.use_rollout_routing_replay: args.use_routing_replay = True From e1934179e8bf16ae98c0fa926d73ff27a50fcfa0 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 14 Jan 2026 10:40:58 +0000 Subject: [PATCH 06/12] 1.minor fix 2.change Lora to CanonicalLoRA - fix cuda problem. 
- efficient Lora should be supported --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 16 +- miles/backends/megatron_utils/actor.py | 2 +- miles/backends/megatron_utils/checkpoint.py | 28 +- miles/backends/megatron_utils/lora_utils.py | 325 ++++++++++++++---- miles/backends/megatron_utils/model.py | 48 ++- .../hf_weight_iterator_bridge.py | 154 ++++----- miles/utils/arguments.py | 1 + 7 files changed, 409 insertions(+), 165 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index b5f257d86..52d1075ac 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -50,10 +50,10 @@ LORA_ARGS=( ############################## # # Debug #### inference - --debug-rollout-only + #--debug-rollout-only ### --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/lora_adapter.pt - --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA - # --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + #--lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA + ## --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ ### #### train @@ -86,12 +86,14 @@ ROLLOUT_ARGS=( --rm-type math # --num-rollout 100 --num-rollout 10 # onyl train 10 stesp - --rollout-batch-size 32 + # --rollout-batch-size 32 + --rollout-batch-size 16 --n-samples-per-prompt 8 --rollout-max-response-len 1024 --rollout-temperature 1 - --global-batch-size 256 + # --global-batch-size 256 + --global-batch-size 32 ) EVAL_ARGS=( @@ -168,7 +170,9 @@ MISC_ARGS=( ############################## ###########lora############### ############################## -export GPUS_PER_NODE=2 +# export GPUS_PER_NODE=1 +# export GPUS_PER_NODE=4 +export GPUS_PER_NODE=8 ############################## ############################## ############################## diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index e0634d39e..05753d3d9 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -43,7 +43,7 @@ from .lora_utils import ( is_lora_enabled, is_lora_model, - apply_lora_to_megatron_model, + # apply_lora_to_megatron_model, # get_lora_weights_and_config, freeze_base_model, ) diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 914bab96d..9681923f7 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -30,6 +30,7 @@ ############################## + def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt): # ref: how megatron `load_checkpoint` gets directory args = get_args() @@ -43,16 +44,31 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con ###########lora############### ############################## # Check for LoRA adapter first - lora_path = Path(load_path) / "adapter" - if lora_path.exists() and is_lora_model(ddp_model): - logger.info(f"Loading LoRA checkpoint from {lora_path}") - iteration = load_lora_checkpoint(ddp_model, args, str(lora_path)) - num_floating_point_operations_so_far = 0 - return iteration, num_floating_point_operations_so_far + ## Not correct - need to check the saving format and name + if is_lora_model(ddp_model): + lora_path = Path(load_path) / "adapter" + if lora_path.exists(): + logger.info(f"Loading LoRA checkpoint from {lora_path}") + iteration = load_lora_checkpoint(ddp_model, args, 
str(lora_path)) + num_floating_point_operations_so_far = 0 + return iteration, num_floating_point_operations_so_far + else: + logger.info(f"Not Found LoRA checkpoint from {lora_path}. Use the random initial weight.") ############################## ############################## ############################## + + ############################## + ###########lora############### + ############################## + # (to-do): yusheng- Also need to add megatron load lora function + ############################## + ############################## + ############################## + ############################## + + if _is_megatron_checkpoint(load_path): return _load_checkpoint_megatron( ddp_model=ddp_model, diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py index 44f9355b4..0401131b4 100644 --- a/miles/backends/megatron_utils/lora_utils.py +++ b/miles/backends/megatron_utils/lora_utils.py @@ -27,78 +27,78 @@ def is_lora_enabled(args: Namespace) -> bool: return args.lora_rank > 0 or args.lora_adapter_path is not None -def apply_lora_to_megatron_model( - model: Sequence[torch.nn.Module], - args: Namespace, -) -> Sequence[torch.nn.Module]: - """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. +# def apply_lora_to_megatron_model( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> Sequence[torch.nn.Module]: +# """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. - This uses the Megatron-Bridge's PEFT support from: - https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft +# This uses the Megatron-Bridge's PEFT support from: +# https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft - Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 +# Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 - Args: - model: Megatron model (DDP wrapped) - args: Training arguments with LoRA config +# Args: +# model: Megatron model (DDP wrapped) +# args: Training arguments with LoRA config - Returns: - LoRA-wrapped model - """ - # from megatron.bridge.peft import apply_lora_adapter, LoraConfig - from megatron.bridge.peft.lora import LoRA - - if args.lora_adapter_path: - # TODO: Loading existing LoRA adapter needs separate implementation - # Megatron-Bridge may have different API for loading - # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 - raise NotImplementedError("Loading existing LoRA adapter is not yet implemented") - else: - # Determine lora_dtype from args - if hasattr(args, 'bf16') and args.bf16: - lora_dtype = torch.bfloat16 - elif hasattr(args, 'fp16') and args.fp16: - lora_dtype = torch.float16 - else: - lora_dtype = None # Will use model's dtype +# Returns: +# LoRA-wrapped model +# """ +# # from megatron.bridge.peft import apply_lora_adapter, LoraConfig +# from megatron.bridge.peft.lora import LoRA + +# if args.lora_adapter_path: +# # TODO: Loading existing LoRA adapter needs separate implementation +# # Megatron-Bridge may have different API for loading +# # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# raise NotImplementedError("Loading existing LoRA adapter is not yet 
implemented") +# else: +# # Determine lora_dtype from args +# if hasattr(args, 'bf16') and args.bf16: +# lora_dtype = torch.bfloat16 +# elif hasattr(args, 'fp16') and args.fp16: +# lora_dtype = torch.float16 +# else: +# lora_dtype = None # Will use model's dtype - # Get exclude_modules as list - exclude_modules = [] - if hasattr(args, 'exclude_modules') and args.exclude_modules: - if isinstance(args.exclude_modules, str): - exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] - else: - exclude_modules = list(args.exclude_modules) +# # Get exclude_modules as list +# exclude_modules = [] +# if hasattr(args, 'exclude_modules') and args.exclude_modules: +# if isinstance(args.exclude_modules, str): +# exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] +# else: +# exclude_modules = list(args.exclude_modules) - # Create new LoRA adapter using Megatron-Bridge LoRA dataclass - # There are different lora_type, I just use the classic one (speed and acc might not the optimal) - # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 - lora = LoRA( - target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] - exclude_modules=exclude_modules, # Modules to exclude from LoRA - dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) - alpha=args.lora_alpha, # LoRA alpha scaling factor - dropout=args.lora_dropout, # LoRA dropout rate - dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' - lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix - lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix - a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication - lora_dtype=lora_dtype, # Parameter data type for LoRA weights - ) - logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " - f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " - f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") +# # Create new LoRA adapter using Megatron-Bridge LoRA dataclass +# # There are different lora_type, I just use the classic one (speed and acc might not the optimal) +# # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# lora = LoRA( +# target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] +# exclude_modules=exclude_modules, # Modules to exclude from LoRA +# dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) +# alpha=args.lora_alpha, # LoRA alpha scaling factor +# dropout=args.lora_dropout, # LoRA dropout rate +# dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' +# lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix +# lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix +# a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication +# lora_dtype=lora_dtype, # Parameter data type for LoRA weights +# ) +# logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " +# f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " +# f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") - # Apply LoRA to each model 
chunk - # The LoRA class is callable - calling it applies the transformation - for model_chunk in model: - # lora(model_chunk.module, training=True) applies LoRA and freezes base model - lora(model_chunk.module, training=True) +# # Apply LoRA to each model chunk +# # The LoRA class is callable - calling it applies the transformation +# for model_chunk in model: +# # lora(model_chunk.module, training=True) applies LoRA and freezes base model +# lora(model_chunk.module, training=True) - # Print trainable parameters info - _print_trainable_parameters(model) +# # Print trainable parameters info +# _print_trainable_parameters(model) - return model +# return model def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: @@ -129,12 +129,13 @@ def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: def is_lora_model(model: Sequence[torch.nn.Module]) -> bool: """Check if model has LoRA layers applied.""" + for model_chunk in model: if hasattr(model_chunk.module, "peft_config"): return True # Check for LoRA layers in parameters for name, _ in model_chunk.named_parameters(): - if "lora_" in name: + if "lora_" in name or "adapter" in name: return True return False @@ -423,6 +424,202 @@ def load_lora_checkpoint( +####!!!! (to-do) yusheng: need to based on different Lora to provide the different mapping +from typing import Union, Type, TYPE_CHECKING +if TYPE_CHECKING: + from megatron.bridge.peft.lora import LoRA + from megatron.bridge.peft.canonical_lora import CanonicalLoRA + +def convert_target_modules_to_megatron( + hf_modules: list[str], + lora_type: Union[Type, object, None] = None +) -> list[str]: + """Convert HuggingFace LoRA target module names to Megatron format. + + HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj + + Megatron (Standard LoRA): linear_qkv, linear_proj, linear_fc1, linear_fc2 + Megatron (CanonicalLoRA): linear_q, linear_k, linear_v, linear_proj, + linear_fc1_up, linear_fc1_gate, linear_fc2 + + Special cases: + - "all", "all-linear", "all_linear" -> returns all standard Megatron linear modules + + Args: + hf_modules: List of HuggingFace module names or single string + lora_type: LoRA class or instance (LoRA, CanonicalLoRA, etc.) + If None, defaults to CanonicalLoRA format + + Returns: + List of Megatron module names + + If input is already in Megatron format, returns as-is without conversion. 
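+
+    Example (illustrative only):
+        convert_target_modules_to_megatron(["q_proj", "k_proj", "v_proj"])
+        # -> ["linear_q", "linear_k", "linear_v"]   (default: CanonicalLoRA mapping)
+        convert_target_modules_to_megatron(["q_proj", "k_proj", "v_proj"], lora_type=LoRA)
+        # -> ["linear_qkv"]   (standard LoRA: fused QKV projection)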
+ """ + # Get the class name whether lora_type is a class or an instance + if lora_type is not None: + if isinstance(lora_type, type): + # It's a class + class_name = lora_type.__name__ + else: + # It's an instance + class_name = type(lora_type).__name__ + + logger.info(f"Converting target modules for {class_name}") + else: + # Default to CanonicalLoRA if not specified + class_name = "CanonicalLoRA" + logger.info(f"Converting target modules (defaulting to CanonicalLoRA)") + + # Handle special cases for "all" variants + if isinstance(hf_modules, str): + if hf_modules in ["all", "all-linear", "all_linear"]: + if class_name == "CanonicalLoRA": + return ["linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + else: # Standard LoRA + return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + # Convert single string to list + hf_modules = [hf_modules] + elif isinstance(hf_modules, list) and len(hf_modules) == 1: + if hf_modules[0] in ["all", "all-linear", "all_linear"]: + if class_name == "CanonicalLoRA": + return ["linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + else: # Standard LoRA + return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + + # Define module name sets based on LoRA type + if class_name == "CanonicalLoRA": + megatron_modules_set = { + "linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2" + } + else: # Standard LoRA + megatron_modules_set = {"linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"} + + hf_modules_set = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + + # Check if all modules are already in Megatron format (or wildcards) + all_megatron_format = True + for module in hf_modules: + # Skip wildcard patterns (e.g., "*.layers.0.*.linear_qkv") + if "*" in module: + continue + # Check if it's a known HF module name + if module in hf_modules_set: + all_megatron_format = False + break + + # If already in Megatron format, return as-is + if all_megatron_format: + return hf_modules + + # Otherwise, perform conversion based on LoRA type + if class_name == "CanonicalLoRA": + # CanonicalLoRA: Split Q/K/V and up/gate + hf_to_megatron = { + "q_proj": "linear_q", + "k_proj": "linear_k", + "v_proj": "linear_v", + "o_proj": "linear_proj", + "gate_proj": "linear_fc1_gate", + "up_proj": "linear_fc1_up", + "down_proj": "linear_fc2", + } + else: # Standard LoRA + # Standard LoRA: Merged Q/K/V and merged up/gate + hf_to_megatron = { + "q_proj": "linear_qkv", + "k_proj": "linear_qkv", + "v_proj": "linear_qkv", + "o_proj": "linear_proj", + "gate_proj": "linear_fc1", + "up_proj": "linear_fc1", + "down_proj": "linear_fc2", + } + + megatron_modules = [] + for module in hf_modules: + if module in hf_to_megatron: + megatron_name = hf_to_megatron[module] + if megatron_name not in megatron_modules: + megatron_modules.append(megatron_name) + else: + # Keep as-is if not in mapping (might already be Megatron format or wildcard) + if module not in megatron_modules: + megatron_modules.append(module) + + return megatron_modules + + + +# def convert_target_modules_to_megatron(hf_modules: list[str]) -> list[str]: +# """Convert HuggingFace LoRA target module names to Megatron format. 
+ +# HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj +# Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + +# Special cases: +# - "all", "all-linear", "all_linear" -> returns all standard Megatron linear modules + +# If input is already in Megatron format, returns as-is without conversion. +# """ +# # Handle special cases for "all" variants +# if isinstance(hf_modules, str): +# if hf_modules in ["all", "all-linear", "all_linear"]: +# return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] +# # Convert single string to list +# hf_modules = [hf_modules] +# elif isinstance(hf_modules, list) and len(hf_modules) == 1: +# if hf_modules[0] in ["all", "all-linear", "all_linear"]: +# return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + +# # Define Megatron and HF module name sets +# megatron_modules_set = {"linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"} +# hf_modules_set = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + +# # Check if all modules are already in Megatron format (or wildcards) +# all_megatron_format = True +# for module in hf_modules: +# # Skip wildcard patterns (e.g., "*.layers.0.*.linear_qkv") +# if "*" in module: +# continue +# # Check if it's a known HF module name +# if module in hf_modules_set: +# all_megatron_format = False +# break + +# # If already in Megatron format, return as-is +# if all_megatron_format: +# return hf_modules + +# # Otherwise, perform conversion +# hf_to_megatron = { +# "q_proj": "linear_qkv", +# "k_proj": "linear_qkv", +# "v_proj": "linear_qkv", +# "o_proj": "linear_proj", +# "gate_proj": "linear_fc1", +# "up_proj": "linear_fc1", +# "down_proj": "linear_fc2", +# } + +# megatron_modules = [] +# for module in hf_modules: +# if module in hf_to_megatron: +# megatron_name = hf_to_megatron[module] +# if megatron_name not in megatron_modules: +# megatron_modules.append(megatron_name) +# else: +# # Keep as-is if not in mapping (might already be Megatron format or wildcard) +# if module not in megatron_modules: +# megatron_modules.append(module) + +# return megatron_modules + + + + def freeze_base_model(model: Sequence[torch.nn.Module]) -> None: """Freeze base model parameters, only keep LoRA trainable.""" for model_chunk in model: diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 14edfe011..7c2e3dd05 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -32,7 +32,8 @@ from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora from .lora_utils import is_lora_model # from miles.backends.megatron_utils.lora_utils import is_lora_enabled -from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model +# from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model +from miles.backends.megatron_utils.lora_utils import is_lora_enabled, convert_target_modules_to_megatron ############################## ############################## ############################## @@ -161,16 +162,25 @@ def setup_model_and_optimizer( ########### + # This part can be moved to `lora_utils.py` def apply_lora_to_megatron_model if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": + # if is_lora_enabled(args) and args.megatron_to_hf_mode == "bridge": + # The below written as: get_model_provider_func() usage from megatron.core.distributed import DistributedDataParallelConfig from 
megatron.bridge.models.model_provider import get_model as bridge_get_model from megatron.bridge import AutoBridge from megatron.bridge.peft.lora import LoRA + from megatron.bridge.peft.canonical_lora import CanonicalLoRA import torch + + # This is register_canonical_lora_adapter usgae - more advnace and efficiency!!!! + # Compare lora, canonical_lora_adapter, .... # Build the provider from HF checkpoint bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) - provider = bridge.to_megatron_provider(load_weights=False) + ##!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # provider = bridge.to_megatron_provider(load_weights=False) + provider = bridge.to_megatron_provider(load_weights=True) # different from full model training - in the training script, I need to load tuned base model weight and initial lora weights. Need to carefully check and optimize - where to load the base model? (but why in `model_provider.py` using: provider = bridge.to_megatron_provider(load_weights=False)) # Set parallel configs on the provider provider.tensor_model_parallel_size = args.tensor_model_parallel_size @@ -194,19 +204,26 @@ def setup_model_and_optimizer( exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] else: exclude_modules = list(args.exclude_modules) + # Convert HF module names to Megatron format + # exclude_modules = convert_target_modules_to_megatron(exclude_modules) + exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=CanonicalLoRA) # Create LoRA config - lora = LoRA( - target_modules=args.target_modules, + # lora = LoRA( + lora = CanonicalLoRA( + # target_modules=args.target_modules, + # target_modules=convert_target_modules_to_megatron(args.target_modules), + target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=CanonicalLoRA), exclude_modules=exclude_modules, dim=args.lora_rank, alpha=args.lora_alpha, dropout=args.lora_dropout, - dropout_position=getattr(args, 'lora_dropout_position', 'pre'), - lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), - lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), - a2a_experimental=getattr(args, 'lora_a2a_experimental', False), - lora_dtype=lora_dtype, + ##Below for Lora + # dropout_position=getattr(args, 'lora_dropout_position', 'pre'), + # lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), + # lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), + # a2a_experimental=getattr(args, 'lora_a2a_experimental', False), + # lora_dtype=lora_dtype, ) # Define pre_wrap_hook to apply LoRA before DDP wrapping @@ -240,7 +257,8 @@ def apply_lora_hook(model_chunks): fp16=getattr(args, 'fp16', False), pre_wrap_hook=provider.pre_wrap_hook, ) - + + # ???? the below can be remove ??? 
or it's tag to jude the model is (base) or (base + lora) # Store lora instance for later use (e.g., checkpoint saving) # You may want to attach this to the model or args for later access if hasattr(args, '_lora_instance'): @@ -250,7 +268,6 @@ def apply_lora_hook(model_chunks): # Original non-LoRA path or non-bridge mode model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) - ############################## ############################## ############################## @@ -953,6 +970,15 @@ def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: if should_log: logger.error(f"Failed to save HuggingFace format: {e}") + + ############################## + ###########lora############### + ############################## + # to-do: also need to impl lora saving + ############################## + ############################## + ############################## + def initialize_model_and_optimizer( args: Namespace, role: str = "actor" diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 4927b2000..f4a0b752f 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -49,74 +49,74 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## - ## This is the origin way - weight sync will process - base model + lora weights - ## to-do (yusheng): Optimize: use the method in `self.is_lora` but need to deal with CUDA issue (weight not on the same device) - might need to be delt with in megatron-core + # ## This is the origin way - weight sync will process - base model + lora weights + # ## to-do (yusheng): Optimize: use the method in `self.is_lora` but need to deal with CUDA issue (weight not on the same device) - might need to be delt with in megatron-core - conversion_tasks = self._bridge.get_conversion_tasks(self.model) - conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + # named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) - # for hf_param_name, weight, megatron_param_name in named_weights: - # print(hf_param_name) - # exit() + # # for hf_param_name, weight, megatron_param_name in named_weights: + # # print(hf_param_name) + # # exit() - named_weights = ( - ( - hf_param_name, - postprocess_hf_param( - args=self.args, - megatron_param_name=megatron_param_name, - hf_param_name=hf_param_name, - param=weight, - ), - ) - for hf_param_name, weight, megatron_param_name in named_weights - ) + # named_weights = ( + # ( + # hf_param_name, + # postprocess_hf_param( + # args=self.args, + # megatron_param_name=megatron_param_name, + # hf_param_name=hf_param_name, + # param=weight, + # ), + # ) + # for hf_param_name, weight, megatron_param_name in named_weights + # ) - yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) #### - # # Only sync base 
model on first call (or if not LoRA-only mode) - # # if not self.is_lora or not self._base_synced: - # if not self.is_lora: - # # only pass base model - # conversion_tasks = self._bridge.get_conversion_tasks(self.model) - # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - # named_weights = self._bridge.export_hf_weights( - # self.model, - # cpu=False, - # conversion_tasks=conversion_tasks, - # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). - # ) + # Only sync base model on first call (or if not LoRA-only mode) + # if not self.is_lora or not self._base_synced: + if not self.is_lora: + # only pass base model + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + named_weights = self._bridge.export_hf_weights( + self.model, + cpu=False, + conversion_tasks=conversion_tasks, + merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). + ) - # # for hf_param_name, weight, megatron_param_name in named_weights: - # # print(hf_param_name) + # for hf_param_name, weight, megatron_param_name in named_weights: + # print(hf_param_name) - # named_weights = ( - # ( - # ############################## - # ###########lora############### - # ############################## - # hf_param_name, - # # _normalize_base_weight_name(hf_param_name), - # ############################## - # ############################## - # ############################## - # postprocess_hf_param( - # args=self.args, - # megatron_param_name=megatron_param_name, - # hf_param_name=hf_param_name, - # param=weight, - # ), - # ) - # for hf_param_name, weight, megatron_param_name in named_weights - # ) - # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) - # if self.is_lora: - # self._base_synced = True - # # torch.cuda.synchronize() + named_weights = ( + ( + ############################## + ###########lora############### + ############################## + hf_param_name, + # _normalize_base_weight_name(hf_param_name), + ############################## + ############################## + ############################## + postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in named_weights + ) + yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + if self.is_lora: + self._base_synced = True + # torch.cuda.synchronize() ############################## ############################## ############################## @@ -126,27 +126,27 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## - # if self.is_lora: - # lora_weights = self._bridge.export_adapter_weights( - # self.model, - # cpu=False, - # show_progress=False - # ) + if self.is_lora: + lora_weights = self._bridge.export_adapter_weights( + self.model, + cpu=False, + show_progress=False + ) - # lora_weights = ( - # ( - # hf_param_name, - # postprocess_hf_param( - # args=self.args, - # megatron_param_name=megatron_param_name, - # hf_param_name=hf_param_name, - # param=weight, - # ), - # ) - # for hf_param_name, weight, megatron_param_name in lora_weights - # ) + lora_weights = ( + ( + hf_param_name, + postprocess_hf_param( + 
args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in lora_weights + ) - # yield from chunk_named_params_by_size(lora_weights, chunk_size=self.args.update_weight_buffer_size) + yield from chunk_named_params_by_size(lora_weights, chunk_size=self.args.update_weight_buffer_size) ############################## ############################## ############################## diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 78c6ac01f..43e147d58 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1616,6 +1616,7 @@ def miles_validate_args(args): assert args.save is not None, "'--save' is required when LoRA is enabled." assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + # (to-do) yusheng: hf->mg; mg->hf if args.target_modules == "all-linear": modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] elif "," in args.target_modules: From 382e9d5893c84be99cc00e7b6a47a67987946180 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 17 Jan 2026 02:31:28 +0000 Subject: [PATCH 07/12] done - but need to fix weightupdate problem --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 16 +- miles/backends/megatron_utils/actor.py | 162 ++++++---- miles/backends/megatron_utils/checkpoint.py | 25 +- miles/backends/megatron_utils/lora_utils.py | 20 +- miles/backends/megatron_utils/model.py | 303 ++++++++++++++---- .../backends/megatron_utils/model_provider.py | 39 ++- .../update_weight/hf_weight_iterator_base.py | 6 + .../hf_weight_iterator_bridge.py | 87 ++++- .../update_weight_from_tensor.py | 167 +++++----- miles/backends/sglang_utils/sglang_engine.py | 43 +-- miles/utils/arguments.py | 2 +- train.py | 38 ++- 12 files changed, 623 insertions(+), 285 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 52d1075ac..158e28cac 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -28,6 +28,10 @@ CKPT_ARGS=( ### ) +# target-module only support: linear (the proj_ need to be supported in future: in Megatron-Bridge/src/megatron/bridge/models/conversion/peft_bridge.py - build_adapter_conversion_tasks) +# example: if one module have two lora: (linear_proj): LoRALinear(), (linear_fc2): LoRALinear() + +# LORA_TARGET_MODULES=${LORA_TARGET_MODULES:-"['linear_qkv','linear_proj','linear_fc1','linear_fc2']"}. It will broken ############################## ###########lora############### @@ -87,13 +91,13 @@ ROLLOUT_ARGS=( # --num-rollout 100 --num-rollout 10 # onyl train 10 stesp # --rollout-batch-size 32 - --rollout-batch-size 16 + --rollout-batch-size 16 # for testing --n-samples-per-prompt 8 --rollout-max-response-len 1024 --rollout-temperature 1 # --global-batch-size 256 - --global-batch-size 32 + --global-batch-size 32 # for testing ) EVAL_ARGS=( @@ -106,7 +110,7 @@ EVAL_ARGS=( PERF_ARGS=( --tensor-model-parallel-size 1 - --sequence-parallel #becasue of lora training error: RuntimeError: Cannot access the main gradient of a frozen parameter. main_grad is None. 
(enable) + --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 --expert-model-parallel-size 1 @@ -170,9 +174,11 @@ MISC_ARGS=( ############################## ###########lora############### ############################## -# export GPUS_PER_NODE=1 +######## Note: Need to set export CUDA_VISIBLE_DEVICES= , or it will fail and have cuda error +export GPUS_PER_NODE=1 +# export GPUS_PER_NODE=2 # export GPUS_PER_NODE=4 -export GPUS_PER_NODE=8 +# export GPUS_PER_NODE=8 ############################## ############################## ############################## diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 05753d3d9..4e00a78f9 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -86,8 +86,10 @@ def init( } dist.barrier(group=get_gloo_group()) + if args.offload_train: if (x := args.train_memory_margin_bytes) > 0: + # --train-memory-margin-bytes can tune this logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") torch_memory_saver.memory_margin_bytes = x @@ -103,33 +105,6 @@ def init( (self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer( args, role ) - - ### share ref model - ############################## - ###########lora############### - ############################## - # # For LoRA with share-ref-base-model: backup base model weights BEFORE applying LoRA - # if is_lora_enabled(args) and role == "actor" and with_ref and getattr(args, 'share_ref_base_model', False): - # # Create weights_backuper early to backup base weights as "ref" before LoRA - # self.weights_backuper = TensorBackuper.create( - # source_getter=lambda: named_params_and_buffers( - # self.args, - # self.model, - # convert_to_global_name=args.megatron_to_hf_mode == "raw", - # translate_gpu_to_cpu=not self.args.enable_weights_backuper, - # ), - # single_tag=None if args.enable_weights_backuper else "actor", - # ) - # self.weights_backuper.backup("ref") # Backup base weights as ref BEFORE LoRA - # logger.info("Backed up base model weights as 'ref' before applying LoRA (share-ref-base-model mode)") - - # if is_lora_enabled(args) and role == "actor": - # self.model = apply_lora_to_megatron_model(self.model, args) - # freeze_base_model(self.model) - ############################## - ############################## - ############################## - if role == "critic": if self.args.offload_train: @@ -147,46 +122,25 @@ def init( ), single_tag=None if args.enable_weights_backuper else "actor", ) - # Deal with actor model --> delt with in model.py - # ############################## - # ###########lora############### - # ############################## - # if is_lora_enabled(args): - # # self.weights_backuper.backup("ref") # Backup base weights as ref BEFORE LoRA (prevent load model weight again on later) - - # self.model = apply_lora_to_megatron_model(self.model, args) # model: base + lora including `requires_grad` process - # # freeze_base_model(self.model) # Set `requires_grad`: base + lora .. do not set here since self.weights_backuper.backup(...) 
does not process `requires_grad` - # ############################## - # ############################## - # ############################## self._active_model_tag: str | None = "actor" self.weights_backuper.backup("actor") - + #actor already include lora now - You can use self.weights_backuper.get("actor") to check + + # ############### + # print("=======") + # actor_weights = self.weights_backuper.get("actor") + + # for name, weight in actor_weights.items(): + # print(f"{name}: shape={weight.shape}, sum={weight.float().sum().item()}, requires_grad: {weight.requires_grad}") + # actor_weights = self.weights_backuper.get("actor") + # print("=======") + # exit() + # and then it will update_weight (sync) to sglang - so it does not need requires_grad. + # ############### + if with_ref: - ############################## - ###########lora############### - ############################## - # self.load_other_checkpoint("ref", args.ref_load) - - # if use lora: --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ (should be also lora weight) - if is_lora_enabled(args): - raise NotImplementedError( - "LoRA with reference model is not yet fully implemented. " - "Please remove reference model settings from your training script:\n" - " 0. Might need to ensure self.load_other_checkpoint can load loar module as well.\n" - " 1. Remove '--use-kl-loss' flag, OR\n" - " 2. Set '--kl-coef 0' without '--use-kl-loss', OR\n" - " 3. Remove '--ref-load' parameter\n" - "This will disable reference model loading (with_ref=False) and allow LoRA training to proceed." - ) - else: - self.load_other_checkpoint("ref", args.ref_load) - - ############################## - ############################## - ############################## - + self.load_other_checkpoint("ref", args.ref_load) if self.args.keep_old_actor: # Load old_actor checkpoint @@ -332,11 +286,11 @@ def _switch_model(self, target_tag: str) -> None: ############################## ###########lora############### ############################## - # Restore requires_grad after weight restoration - # For LoRA training: only adapter params should be trainable, base model frozen - if is_lora_enabled(self.args): - freeze_base_model(self.model) - # Note: ref model uses forward_only (@torch.no_grad), so requires_grad doesn't matter + # # Restore requires_grad after weight restoration + # # For LoRA training: only adapter params should be trainable, base model frozen + # if is_lora_enabled(self.args): + # freeze_base_model(self.model) + # # Note: ref model uses forward_only (@torch.no_grad), so requires_grad doesn't matter ############################## ############################## ############################## @@ -687,3 +641,75 @@ def connect_actor_critic( rank=0 if self.role == "actor" else 1, group_name=group_name, ) + + + + + ############################## + ###########lora############### + ############################## + ########### + def check_lora_status(self): + """check LoRA module""" + from megatron.bridge.peft.lora_layers import LoRALinear + from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + + results = { + "lora_modules": [], + "trainable_params": 0, + "frozen_params": 0, + "total_params": 0, + } + + model = self.model[0] if isinstance(self.model, list) else self.model + + # travel all module + for name, module in model.named_modules(): + if isinstance(module, (LoRALinear, AdapterWrapper)): + results["lora_modules"].append(name) + + # check adapter weight + if hasattr(module, 'adapter'): + adapter = module.adapter + if hasattr(adapter, 'linear_in'): + 
lora_a_shape = tuple(adapter.linear_in.weight.shape) + lora_b_shape = tuple(adapter.linear_out.weight.shape) + print(adapter.linear_in.weight.shape) + # print(adapter.linear_in.weight) + # print(adapter.linear_in.weight.detach().cpu().clone()) + print(adapter.linear_out.weight.shape) + # print(adapter.linear_out.weight) + # print(adapter.linear_out.weight.detach().cpu().clone()) + print(f" {name}: lora_A={lora_a_shape}, lora_B={lora_b_shape}, requires_grad={adapter.linear_in.requires_grad}") + else: + for param_name, param in module.named_parameters(recurse=False): + # print(f" {name}.{param_name}: shape={tuple(param.shape)}", param) + # print(f" {name}.{param_name}: shape={tuple(param.shape)}", param.float().sum().item()) + print(f" {name}.{param_name}: shape={tuple(param.shape)}") + + + # account parms + for p in model.parameters(): + results["total_params"] += p.numel() + if p.requires_grad: + results["trainable_params"] += p.numel() + else: + results["frozen_params"] += p.numel() + + print(f"LoRA Check Results:") + print(f" - LoRA modules found: {len(results['lora_modules'])}") + print(f" - Trainable params: {results['trainable_params']:,}") + print(f" - Frozen params: {results['frozen_params']:,}") + print(f" - Trainable ratio: {results['trainable_params']/results['total_params']*100:.2f}%") + + return results + ############################## + ############################## + ############################## + + # │ print(tensor.shape) ✓ │ + # │ └── only read metadata(do not need GPU memory) │ + # │ │ + # │ print(tensor) ✗ │ + # │ └── PyTorch __repr__ will read the actual tensor values │ + # │ └── When GPU try to read the data, it has cudaErrorIllegalAddress \ No newline at end of file diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 9681923f7..24fb2c6cc 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -43,17 +43,17 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con ############################## ###########lora############### ############################## - # Check for LoRA adapter first - ## Not correct - need to check the saving format and name - if is_lora_model(ddp_model): - lora_path = Path(load_path) / "adapter" - if lora_path.exists(): - logger.info(f"Loading LoRA checkpoint from {lora_path}") - iteration = load_lora_checkpoint(ddp_model, args, str(lora_path)) - num_floating_point_operations_so_far = 0 - return iteration, num_floating_point_operations_so_far - else: - logger.info(f"Not Found LoRA checkpoint from {lora_path}. Use the random initial weight.") + # # Check for LoRA adapter first + # ## Not correct - need to check the saving format and name + # if is_lora_model(ddp_model): + # lora_path = Path(load_path) / "adapter" + # if lora_path.exists(): + # logger.info(f"Loading LoRA checkpoint from {lora_path}") + # iteration = load_lora_checkpoint(ddp_model, args, str(lora_path)) + # num_floating_point_operations_so_far = 0 + # return iteration, num_floating_point_operations_so_far + # else: + # logger.info(f"Not Found LoRA checkpoint from {lora_path}. 
Use the random initial weight.") ############################## ############################## ############################## @@ -62,7 +62,8 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con ############################## ###########lora############### ############################## - # (to-do): yusheng- Also need to add megatron load lora function + # (to-do): yusheng- Also add lora weight loading in `_load_checkpoint_megatron` and `_load_checkpoint_hf` + # if no lora weight - random initalization ############################## ############################## ############################## diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py index 0401131b4..eb130a50c 100644 --- a/miles/backends/megatron_utils/lora_utils.py +++ b/miles/backends/megatron_utils/lora_utils.py @@ -101,6 +101,19 @@ def is_lora_enabled(args: Namespace) -> bool: # return model + +# def print_adapter_info(model): +# """Print information about adapter parameters in the model.""" +# adapter_params, total_params, percentage = count_adapter_parameters(model) + +# print(f"\n{'=' * 60}") +# print("PEFT Adapter Information:") +# print(f" Total parameters: {total_params:,}") +# print(f" Adapter parameters: {adapter_params:,}") +# print(f" Trainable percentage: {percentage:.2f}%") +# print(f"{'=' * 60}\n") + + def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: """Print trainable parameters statistics.""" total_params = 0 @@ -423,7 +436,7 @@ def load_lora_checkpoint( - +## Check this functions - megatron-bridge might have the same function ####!!!! (to-do) yusheng: need to based on different Lora to provide the different mapping from typing import Union, Type, TYPE_CHECKING if TYPE_CHECKING: @@ -526,7 +539,8 @@ def convert_target_modules_to_megatron( "up_proj": "linear_fc1_up", "down_proj": "linear_fc2", } - else: # Standard LoRA + elif class_name == "LoRA": + # Standard LoRA # Standard LoRA: Merged Q/K/V and merged up/gate hf_to_megatron = { "q_proj": "linear_qkv", @@ -537,6 +551,8 @@ def convert_target_modules_to_megatron( "up_proj": "linear_fc1", "down_proj": "linear_fc2", } + else: + raise NotImplementedError(f"Unsupported LoRA class: {class_name}") megatron_modules = [] for module in hf_modules: diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 7c2e3dd05..2ec58456b 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -45,6 +45,73 @@ logger = logging.getLogger(__name__) +############################## +###########lora############### +############################## +from dataclasses import dataclass +@dataclass +class McoreModuleWrapperConfig: + """Configuration for Mcore module wrapper.""" + is_value_model: bool = False + share_embeddings_and_output_weights: bool = False + wrap_with_ddp: bool = True + use_distributed_optimizer: bool = True + + +def _ensure_model_list(model): + return model if isinstance(model, list) else [model] + +# Get from: verl/verl/models/mcore/bridge.py +def make_value_model(hidden_size, sequence_parallel): + """Creates a pre-wrap hook that replace the output layer with a value head. + + Args: + hidden_size (int): The hidden size of the model's transformer layers. + sequence_parallel (bool): Whether sequence parallelism is enabled. + + Returns: + A hook function that can be used as a `pre_wrap_hook` in Megatron-Bridge. 
+ The hook itself takes the model as input and prepares it for value head activation. + """ + + from megatron.core import parallel_state + from .model_provider import LinearForLastLayer + + def hook(model): + model_post_process = [] + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + model_post_process.append(parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)) + else: + model_post_process.append(parallel_state.is_pipeline_last_stage()) + + model_list = _ensure_model_list(model) + assert len(model_post_process) == len(model_list), "Model list length and post process list length must match." + + for index, model_chunk in enumerate(model_list): + if not model_post_process[index]: + continue + + model_chunk.output_layer = LinearForLastLayer( + input_size=hidden_size, + output_size=1, + sequence_parallel=sequence_parallel, + ) + + return hook + + +from megatron.core.utils import get_attr_wrapped_model +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) +############################## +############################## +############################## + + def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer) -> OptimizerParamScheduler: """Create and configure the optimizer learning-rate/weight-decay scheduler. @@ -164,9 +231,15 @@ def setup_model_and_optimizer( # This part can be moved to `lora_utils.py` def apply_lora_to_megatron_model if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": - # if is_lora_enabled(args) and args.megatron_to_hf_mode == "bridge": + ###### + # refer to: verl/verl/workers/engine/megatron/transformer_impl.py + ###### + + + + # if is_lora_enabled(args) and args.megatron_to_hf_mode == "bridge": # The below written as: get_model_provider_func() usage - from megatron.core.distributed import DistributedDataParallelConfig + # from megatron.core.distributed import DistributedDataParallelConfig from megatron.bridge.models.model_provider import get_model as bridge_get_model from megatron.bridge import AutoBridge from megatron.bridge.peft.lora import LoRA @@ -177,10 +250,23 @@ def setup_model_and_optimizer( # This is register_canonical_lora_adapter usgae - more advnace and efficiency!!!! # Compare lora, canonical_lora_adapter, .... # Build the provider from HF checkpoint + + # model: start + # args.hf_checkpoint + # model: done + + # bridge : start + # hf config: + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + + # bridge: hf --> mg config bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + # bridge : done ##!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - # provider = bridge.to_megatron_provider(load_weights=False) - provider = bridge.to_megatron_provider(load_weights=True) # different from full model training - in the training script, I need to load tuned base model weight and initial lora weights. Need to carefully check and optimize - where to load the base model? (but why in `model_provider.py` using: provider = bridge.to_megatron_provider(load_weights=False)) + # provider: start + provider = bridge.to_megatron_provider(load_weights=False) # should be True??? 
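+        # Assumption (to be verified against Megatron-Bridge): load_weights=False builds
+        # the Megatron model without copying the HF weights, so the base weights have to
+        # come from a later checkpoint load; load_weights=True would copy them from
+        # args.hf_checkpoint at provider-build time, which is what the commented-out
+        # variant below is discussing.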
+ # provider = bridge.to_megatron_provider(load_weights=True) # different from full model training - in the training script, I need to load tuned base model weight and initial lora weights. Need to carefully check and optimize - where to load the base model? (but why in `model_provider.py` using: provider = bridge.to_megatron_provider(load_weights=False)) # Set parallel configs on the provider provider.tensor_model_parallel_size = args.tensor_model_parallel_size @@ -188,15 +274,34 @@ def setup_model_and_optimizer( provider.expert_model_parallel_size = args.expert_model_parallel_size provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel + ##### + provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size + provider.context_parallel_size = args.context_parallel_size + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + # provider: done + provider.finalize() + ##### + + ########### + ########### lora part + ########### + # peft_cls: start + # need: args.hf_checkpoint, bridge, provider, self.param_dtype: lora_dtype + # Determine lora_dtype if hasattr(args, 'bf16') and args.bf16: lora_dtype = torch.bfloat16 elif hasattr(args, 'fp16') and args.fp16: lora_dtype = torch.float16 else: - lora_dtype = None + # !!!!!!!!!!!!!!!!!!!!!! + # lora_dtype = None + lora_dtype = torch.float16 + # (to-do) yusheng - lora_type - can be input arg, the default is LoRA # Get exclude_modules as list exclude_modules = [] if hasattr(args, 'exclude_modules') and args.exclude_modules: @@ -206,14 +311,17 @@ def setup_model_and_optimizer( exclude_modules = list(args.exclude_modules) # Convert HF module names to Megatron format # exclude_modules = convert_target_modules_to_megatron(exclude_modules) - exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=CanonicalLoRA) - + exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=LoRA) + # exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=CanonicalLoRA) + + # print("========") # Create LoRA config - # lora = LoRA( - lora = CanonicalLoRA( - # target_modules=args.target_modules, - # target_modules=convert_target_modules_to_megatron(args.target_modules), - target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=CanonicalLoRA), + # (to-do) yusheng set - lora_type - the default is LoRA + lora = LoRA( + # lora = CanonicalLoRA( + target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=LoRA), + # target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=CanonicalLoRA), + # lora_dtype=lora_dtype, exclude_modules=exclude_modules, dim=args.lora_rank, alpha=args.lora_alpha, @@ -223,47 +331,145 @@ def setup_model_and_optimizer( # lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # a2a_experimental=getattr(args, 'lora_a2a_experimental', False), - # lora_dtype=lora_dtype, ) + + + + # peft_cls: done + ################ --------------------################# + # ----- Above: + # self.bridge: init hf model --> load hf weight + # self.provider: init megatron model --> load megatron weight + # self.peft_cls: init load module --> load lora weight - # Define pre_wrap_hook to apply LoRA before DDP wrapping + + # When using PEFT with Megatron-Bridge, we must apply PEFT transformation + 
# BEFORE wrapping the model in DDP. This is required because: + # 1. PEFT freezes base model parameters (requires_grad=False) + # 2. DDP must be aware of which parameters are trainable when building gradient buckets + # 3. The distributed optimizer must only track trainable (adapter) parameters + # See Megatron-Bridge docs: training/peft.md + + + # # Define pre_wrap_hook to apply LoRA before DDP wrapping def apply_lora_hook(model_chunks): transformed = lora(model_chunks, training=True) lora.set_params_to_save(transformed) + # Load adapter weights if adapter_path is specified + # adapter_path = getattr(peft_config, "adapter_path", None) + # if adapter_path is not None and adapter_path: + # print(f"Loading adapter weights from: {adapter_path}") + # load_adapter_checkpoint(transformed_model, adapter_path) + + # Print PEFT statistics + # if torch.distributed.get_rank() == 0: + # print_adapter_info(transformed) + return transformed - + + # Register the hook provider.register_pre_wrap_hook(apply_lora_hook) - provider.finalize() + # provider.finalize() + + + + + #### ------------------- ##### + #### ------------------- ##### + #### ------------------- ##### + + + # TODO: add more cases + is_value_model = ( + "ForTokenClassification" in hf_config.architectures[0] + or "ForSequenceClassification" in hf_config.architectures[0] + ) + - # Build DDP config - ddp_config = DistributedDataParallelConfig( - grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), - check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), - overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), - overlap_param_gather=getattr(args, 'overlap_param_gather', False), - average_in_collective=getattr(args, 'average_in_collective', False), - use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), + wrap_config = McoreModuleWrapperConfig( + is_value_model=is_value_model, # here is False # actor is not value model + # share_embeddings_and_output_weights=hf_config.share_embeddings_and_output_weights, + # wrap_with_ddp=wrap_with_ddp, #default is True + # use_distributed_optimizer=self.engine_config.use_distributed_optimizer, #default is True ) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + post_model_creation_callbacks = [] + if wrap_config.is_value_model: + hidden_size = ( + hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size + ) + value_model_hook = make_value_model(hidden_size, provider.sequence_parallel) + post_model_creation_callbacks.append(value_model_hook) + # if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + # post_model_creation_callbacks.append(freeze_moe_router) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + for callback in post_model_creation_callbacks: + provider.register_pre_wrap_hook(callback) - # Use Bridge's get_model with the provider (which now has LoRA hook registered) - model = bridge_get_model( - model_provider=provider, + # print("====") + # print(wrap_config.wrap_with_ddp) + # print("====") + # exit() + + if wrap_config.wrap_with_ddp: + from megatron.bridge.training.config import DistributedDataParallelConfig + ddp_config_dict = { + "use_distributed_optimizer": wrap_config.use_distributed_optimizer, + } + # Apply any DDP config overrides + # if override_ddp_config is not None: + # ddp_config_dict.update(override_ddp_config) + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) 
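+            # Note: only `use_distributed_optimizer` is set explicitly here; the
+            # remaining DDP options (e.g. grad_reduce_in_fp32, overlap_grad_reduce)
+            # keep the DistributedDataParallelConfig dataclass defaults unless an
+            # override dict is wired in via the commented-out hook above.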
+ ddp_config.finalize() + + # Now call provide_distributed_model with all hooks registered + # Hooks will be applied automatically before DDP wrapping + model = provider.provide_distributed_model( + wrap_with_ddp=wrap_config.wrap_with_ddp, ddp_config=ddp_config, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_cpu_initialization=False, - bf16=getattr(args, 'bf16', False), - fp16=getattr(args, 'fp16', False), - pre_wrap_hook=provider.pre_wrap_hook, ) + ## Use: + # model = bridge_get_model(...) seems more simple + + + # Extract TransformerConfig from the created model + tf_config = get_model_config(model[0] if isinstance(model, list) else model) + + ############# + # # Build DDP config - config + # ddp_config = DistributedDataParallelConfig( + # grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), + # check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), + # overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), + # overlap_param_gather=getattr(args, 'overlap_param_gather', False), + # average_in_collective=getattr(args, 'average_in_collective', False), + # use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), + # ) + + # good method + # # Use Bridge's get_model with the provider (which now has LoRA hook registered) + # model = bridge_get_model( # this function eqaul get_model() in megatron-core for base model + # model_provider=provider, + # ddp_config=ddp_config, + # model_type=ModelType.encoder_or_decoder, + # wrap_with_ddp=True, + # use_cpu_initialization=False, + # bf16=getattr(args, 'bf16', False), + # fp16=getattr(args, 'fp16', False), + # pre_wrap_hook=provider.pre_wrap_hook, + # ) + # ???? the below can be remove ??? or it's tag to jude the model is (base) or (base + lora) # Store lora instance for later use (e.g., checkpoint saving) # You may want to attach this to the model or args for later access - if hasattr(args, '_lora_instance'): - args._lora_instance = lora - + # if hasattr(args, '_lora_instance'): + # args._lora_instance = lora + ############# + else: # Original non-LoRA path or non-bridge mode model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) @@ -272,31 +478,6 @@ def apply_lora_hook(model_chunks): ############################## ############################## - - ############################## - ###########lora############### - ############################## - # from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model - # if is_lora_enabled(args) and role == "actor": - # model = apply_lora_to_megatron_model(model, args) - - ######### - # if is_lora_enabled(args) and role == "actor": - # from megatron.bridge.peft.lora import LoRA - - # lora = LoRA( - # target_modules=args.target_modules, - # dim=args.lora_rank, - # alpha=args.lora_alpha, - # dropout=args.lora_dropout, - # ) - # # model is list[DDP],it require unwrap - # model = lora(model, training=True) - # lora.set_params_to_save(model) - ############################## - ############################## - ############################## - # Optimizer kwargs = {} for f in dataclasses.fields(OptimizerConfig): diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index c854bfb6f..cac560a6c 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -92,33 +92,32 @@ def wrapped_model_provider( bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, 
trust_remote_code=True) provider = bridge.to_megatron_provider(load_weights=False) + # TODO: we should not manually set this... provider.tensor_model_parallel_size = args.tensor_model_parallel_size provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size provider.expert_model_parallel_size = args.expert_model_parallel_size provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel - ############################## - ###########lora############### - ############################## + + # ############################## + # ###########lora############### + # ############################## # if is_lora_enabled(args) and role == "actor": - # from megatron.bridge.peft.lora import LoRA - # lora = LoRA( - # target_modules=args.target_modules, - # dim=args.lora_rank, - # alpha=args.lora_alpha, - # dropout=args.lora_dropout, - # # lora_dtype=lora_dtype, - # ) - # # Apply LoRA and freeze base model - # def apply_lora(model_chunks): - # transformed = lora(model_chunks, training=True) - # lora.set_params_to_save(transformed) - # return transformed - # provider.register_pre_wrap_hook(apply_lora) - ############################## - ############################## - ############################## + # provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size + # provider.context_parallel_size = args.context_parallel_size + # provider.variable_seq_lengths = True + # provider.moe_token_dispatcher_type = "alltoall" + # provider.moe_router_load_balancing_type = "none" + # provider.finalize() + # return provider.provide + # else: + # provider.finalize() + # return provider.provide + # ############################## + # ############################## + # ############################## + provider.finalize() return provider.provide diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py index 3cedf28db..9acbbd4a1 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py @@ -14,8 +14,14 @@ def create(args, model, **kwargs): return c(args, model, **kwargs) + ############################## + ###########lora############### + ############################## # def __init__(self, args, model, model_name, quantization_config): def __init__(self, args, model, model_name, quantization_config, **kwargs): + ############################## + ############################## + ############################## self.args = args self.model = model self.model_name = model_name diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index f4a0b752f..474b127dc 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -45,6 +45,25 @@ def __init__(self, *args, **kwargs): def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. 
modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} + # print(111111) + # print(renamed_megatron_local_weights) + # print(type(renamed_megatron_local_weights)) + # print(renamed_megatron_local_weights["vp_stages.0.decoder.layers.23.mlp.linear_fc2.adapter.linear_out.weight"]) + # print(renamed_megatron_local_weights["vp_stages.0.decoder.layers.21.mlp.linear_fc2.to_wrap.weight"]) + # print(111111) + + # lora_weights = self._bridge.export_adapter_weights( + # self.model, + # # cpu=False, + # cpu=True, ### if False, it will have the problem - why? + # # conversion_tasks=conversion_tasks, #### + # show_progress=False + # ) + # # print("---") + # # for i in lora_weights: + # # print(i) + # exit() + with megatron_bridge_utils.patch_megatron_model(self.model): ############################## ###########lora############### @@ -78,8 +97,8 @@ def get_hf_weight_chunks(self, megatron_local_weights): #### - # Only sync base model on first call (or if not LoRA-only mode) - # if not self.is_lora or not self._base_synced: + # Only sync base model on first call - smile/miles need (or if not LoRA-only mode) + # if not self.is_lora or self._base_synced: if not self.is_lora: # only pass base model conversion_tasks = self._bridge.get_conversion_tasks(self.model) @@ -88,7 +107,7 @@ def get_hf_weight_chunks(self, megatron_local_weights): self.model, cpu=False, conversion_tasks=conversion_tasks, - merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). + # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). ) # for hf_param_name, weight, megatron_param_name in named_weights: @@ -113,9 +132,11 @@ def get_hf_weight_chunks(self, megatron_local_weights): ) for hf_param_name, weight, megatron_param_name in named_weights ) + yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + if self.is_lora: - self._base_synced = True + self._base_synced = False # torch.cuda.synchronize() ############################## ############################## @@ -126,13 +147,68 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## + # print(4444444) if self.is_lora: + # (to-do) yusheng: I might need to add the converting weights (mg --> hf) - refer above + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + + # print(333333) + # print(self.model) + # print(333333) + # exit() + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = self._bridge.build_adapter_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + # print(999999) + # print(conversion_tasks) + # print(999999) + # exit() + + + ### + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + # lora_weights = self._bridge.export_hf_weights( + # self.model, + # cpu=False, + # conversion_tasks=conversion_tasks, + # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). 
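
# What the `merge_adapter_weights` flag above controls, mathematically: either
# fold the low-rank update into the dense base weight, W' = W + (alpha/r) * (B @ A),
# or export base and adapter tensors separately (the LoRA-only path below).
# `merge_lora` is a hypothetical helper over toy tensors, not the megatron-bridge
# export code.
import torch

def merge_lora(base_w: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               alpha: float, rank: int) -> torch.Tensor:
    # base_w: [out, in], lora_a: [rank, in], lora_b: [out, rank]
    return base_w + (alpha / rank) * (lora_b @ lora_a)

w = torch.randn(128, 64)
a = torch.randn(16, 64)
b = torch.zeros(128, 16)   # zero-initialized B, so the merged weight equals the base
assert torch.equal(merge_lora(w, a, b, alpha=32.0, rank=16), w)
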
+ # ) + ### + + # self.model --> eval mode () ## + # problem in self._bridge.export_adapter_weights() # verl do the same thing + # print(self.model) --> self.model is a list + # for model_module in self.model: + # print(model_module, "training:", model_module.training) + # model_module.eval() + # print(model_module, "training:", model_module.training) + # print("0099") + # print(self.model) + + # print(self.model) + # print("---------") lora_weights = self._bridge.export_adapter_weights( self.model, - cpu=False, + # cpu=False, + cpu=True, ### if False, it will have the problem - why? + # conversion_tasks=conversion_tasks, #### show_progress=False ) + # print(self.model) + # exit() + # for model_module in self.model: + # model_module.train() + + # for item in lora_weights: + # # print(i) + # # print(f"param_name: {item.param_name}, shape: {item[1].shape}, dtype: {item[1].dtype}") + # hf_param_name, weight, megatron_param_name = item + + + # hf_param_name's might have big problem lora_weights = ( ( hf_param_name, @@ -141,6 +217,7 @@ def get_hf_weight_chunks(self, megatron_local_weights): megatron_param_name=megatron_param_name, hf_param_name=hf_param_name, param=weight, + # param=weight.clone(), # solutuon - need to have self._bridge.build_adapter_conversion_tasks in megatron-bridge ), ) for hf_param_name, weight, megatron_param_name in lora_weights diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 80a42eba4..218c43072 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -65,8 +65,9 @@ def __init__( ###########lora############### ############################## self.is_lora = is_lora - self._lora_loaded = False - self._base_synced = False + # self._lora_loaded = False + self._base_synced = True + # self._base_synced = False # self._hf_weight_iterator = HfWeightIteratorBase.create( # args=args, model=model, model_name=model_name, quantization_config=quantization_config @@ -147,100 +148,102 @@ def update_weights(self) -> None: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) - megatron_local_weights = self.weights_getter() + megatron_local_weights = self.weights_getter() - ############################## - ###########lora############### - ############################## - # lora_named_tensors = [] - ############################## - ############################## - ############################## + # error in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights) for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - ############################## - ###########lora############### - ############################## + import logging + logger = logging.getLogger(__name__) refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) ray.get(refs) del long_lived_tensors - - # above original way: will pass base+lora weight - - ########## (Optimize - lora pass only): refer - https://github.com/radixark/miles/blob/3b975df8e3e6af3453d36de491a764334f16b059/miles/backends/fsdp_utils/update_weight_utils.py - # # Only pass throguht lora weight - # # Check if this chunk contains LoRA weights - # if self.is_lora: - # # print() - # lora_weights = [(name, tensor) for name, tensor in hf_named_tensors - # if 'lora_' in name.lower() or 'adapter' in name.lower()] - - # 
base_weights = [(name, tensor) for name, tensor in hf_named_tensors - # if 'lora_' not in name.lower()] - - # # Sync base weights normally - # if base_weights: - # refs, long_lived_tensors = self._send_hf_params(base_weights) - # ray.get(refs) - # del long_lived_tensors - - # # Collect LoRA weights for later - # lora_named_tensors.extend(lora_weights) - # else: - # refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) - # ray.get(refs) - # del long_lived_tensors - - - # # After syncing all weights, load LoRA adapter into SGLang - # if self.is_lora and lora_named_tensors: - # self._load_lora_adapter(lora_named_tensors) - ############################## - ############################## - ############################## dist.barrier(group=get_gloo_group()) - ############################## - ###########lora############### - ############################## - # def _load_lora_adapter(self, lora_named_tensors: list[tuple[str, torch.Tensor]]) -> None: - # """Load LoRA adapter into SGLang engine.""" - # from ..sglang import FlattenedTensorBucket, MultiprocessingSerializer - - # # Create config dict - # config_dict = { - # "peft_type": "LORA", - # "r": self.args.lora_rank, - # "lora_alpha": self.args.lora_alpha, - # "target_modules": list(self.args.target_modules) if self.args.target_modules else [], - # "bias": "none", - # } - - # # Serialize LoRA tensors - # flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) - # flattened_tensor_data = { - # "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - # "metadata": flattened_tensor_bucket.get_metadata(), - # } - # serialized_tensors = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - - # # Load adapter on rank 0 - # rank = dist.get_rank() - # if rank == 0: + + # ############################## + # ###########lora############### + # ############################## + # def _update_lora_via_tensor(self) -> None: + # """Push LoRA weights to rollout engines using tensors.""" + # lora_weights, config_dict = get_lora_weights_and_config(self.model) + # dist.barrier() + + # if dist.get_rank() == 0: + # serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True) + + # if self._lora_loaded: + # refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + # ray.get(refs) + + # refs = [ + # engine.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, serialized_tensors, config_dict) + # for engine in self.rollout_engines + # ] + # ray.get(refs) + + # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + # ray.get(refs) + + # self._lora_loaded = True + + # dist.barrier() + + + + # # _update_lora_via_file -- have not done/fix yet + # def _update_lora_via_file(self) -> None: + # """Push LoRA weights to rollout engines using disk files.""" + # self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) + # if dist.get_rank() == 0: + # if os.path.exists(self._lora_save_dir): + # delete_lora_from_disk(self._lora_save_dir) + + # dist.barrier() + + # save_lora_to_disk(self.model, self._lora_save_dir) + + # dist.barrier() + + # if dist.get_rank() == 0: + # if self._lora_loaded: + # refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + # ray.get(refs) + + # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + # ray.get(refs) + # refs = [ - # engine.load_lora_adapter_from_tensors.remote( - # lora_name=LORA_ADAPTER_NAME, - # 
serialized_tensors=serialized_tensors, - # config_dict=config_dict, - # ) + # engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) # for engine in self.rollout_engines # ] # ray.get(refs) + + # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + # ray.get(refs) + # self._lora_loaded = True - ############################## - ############################## - ############################## + + # dist.barrier() + + # ############################## + # ############################## + # ############################## + def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: + + ############################## + ###########lora############### + ############################## + + # to-do (yusheng): need to deal with update_from_disk or tensor in this function + + ############################## + ############################## + ############################## + + + all_refs = [] refs_colocated, long_lived_tensors = _send_to_colocated_engine( diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index abfb1fb9f..3ebfd8558 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -395,26 +395,26 @@ def get_weight_version(self): ############################## ###########lora############### ############################## - def load_lora_adapter(self, lora_name: str, lora_path: str): - """Load LoRA adapter from disk.""" - return self._make_request( - "load_lora_adapter", - {"lora_name": lora_name, "lora_path": lora_path}, - ) - - def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict): - """Load LoRA adapter from serialized tensors.""" - return self._make_request( - "load_lora_adapter_from_tensors", - {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, - ) - - def unload_lora_adapter(self, lora_name: str): - """Unload LoRA adapter.""" - return self._make_request( - "unload_lora_adapter", - {"lora_name": lora_name}, - ) + # def load_lora_adapter(self, lora_name: str, lora_path: str): + # """Load LoRA adapter from disk.""" + # return self._make_request( + # "load_lora_adapter", + # {"lora_name": lora_name, "lora_path": lora_path}, + # ) + + # def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict): + # """Load LoRA adapter from serialized tensors.""" + # return self._make_request( + # "load_lora_adapter_from_tensors", + # {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, + # ) + + # def unload_lora_adapter(self, lora_name: str): + # """Unload LoRA adapter.""" + # return self._make_request( + # "unload_lora_adapter", + # {"lora_name": lora_name}, + # ) ############################## ############################## ############################## @@ -618,7 +618,8 @@ def _compute_server_args( # # print(2222222222) # # exit() - if args.lora_rank > 0 or args.lora_adapter_path is not None: + # if args.lora_rank > 0 or args.lora_adapter_path is not None: + if is_lora_enabled(args): kwargs["max_loras_per_batch"] = 1 #!!!!!!!! 
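
# A plausible reading of the `is_lora_enabled(args)` gate used above; the real
# helper lives in miles/backends/megatron_utils/lora_utils.py, so treat this as
# an assumption about its behaviour rather than its actual implementation. It
# mirrors the raw condition this hunk replaces: LoRA support in the server is
# needed either when adapters are being trained (lora_rank > 0) or when a
# pre-trained adapter is supplied for rollout-only / debug runs.
import types

def is_lora_enabled(args) -> bool:
    return getattr(args, "lora_rank", 0) > 0 or getattr(args, "lora_adapter_path", None) is not None

assert is_lora_enabled(types.SimpleNamespace(lora_rank=16, lora_adapter_path=None))
assert not is_lora_enabled(types.SimpleNamespace(lora_rank=0, lora_adapter_path=None))
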
kwargs["enable_lora"] = True # kwargs["max_lora_rank"] = args.lora_rank diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 43e147d58..76a042cab 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -970,7 +970,7 @@ def add_lora_arguments(parser): "--lora-adapter-path", type=str, default=None, - help="Path to pre-trained LoRA adapter to load", + help="Path to load pre-trained LoRA adapter weights (default: None)", ) parser.add_argument( "--lora-sync-from-tensor", diff --git a/train.py b/train.py index 2ed15716c..a86106e0d 100644 --- a/train.py +++ b/train.py @@ -36,11 +36,32 @@ def train(args): # create the actor and critic models actor_model, critic_model = create_training_models(args, pgs, rollout_manager) + + # ############################## + # ###########lora############### + # ############################## + # ========== LoRA check ========== + # # check LoRA status + # from miles.backends.megatron_utils.lora_utils import is_lora_enabled + # if is_lora_enabled(args) or True: + # lora_status = ray.get(actor_model._actor_handlers[0].check_lora_status.remote()) + # print(f"LoRA modules: {len(lora_status['lora_modules'])}") + # assert len(lora_status['lora_modules']) > 0, "LoRA modules not found!" + # # already cannot access weight here - before torch_memory_saver + # # ========== LoRA check end ========== + # ############################## + # ###########lora############### + # ############################## + if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # always update weight first so that sglang has the loaded weights from training. + print(11111) actor_model.update_weights() + print(22222) + exit() if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) @@ -90,16 +111,17 @@ def onload_rollout(): ############################## ###########lora############### ############################## + # if args.offload_rollout: + # ray.get(rollout_manager.offload.remote()) + if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) - # if args.offload_rollout: - # offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] - # if "kv_cache" in args.offload_rollout_level: - # offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) - # if "weight" in args.offload_rollout_level: - # offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) - # ray.get(rollout_manager.offload.remote(tags=offload_tags)) ############################## ############################## ############################## From a2494d32224f67a465c08c816b5a7feb09de14d1 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 17 Jan 2026 02:40:15 +0000 Subject: [PATCH 08/12] need to fix weight update --- .../hf_weight_iterator_bridge.py | 70 +----------------- .../update_weight_from_tensor.py | 72 +------------------ 2 files changed, 4 insertions(+), 138 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 474b127dc..9e316ec4e 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -45,24 +45,7 @@ def 
__init__(self, *args, **kwargs): def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} - # print(111111) - # print(renamed_megatron_local_weights) - # print(type(renamed_megatron_local_weights)) - # print(renamed_megatron_local_weights["vp_stages.0.decoder.layers.23.mlp.linear_fc2.adapter.linear_out.weight"]) - # print(renamed_megatron_local_weights["vp_stages.0.decoder.layers.21.mlp.linear_fc2.to_wrap.weight"]) - # print(111111) - - # lora_weights = self._bridge.export_adapter_weights( - # self.model, - # # cpu=False, - # cpu=True, ### if False, it will have the problem - why? - # # conversion_tasks=conversion_tasks, #### - # show_progress=False - # ) - # # print("---") - # # for i in lora_weights: - # # print(i) - # exit() + with megatron_bridge_utils.patch_megatron_model(self.model): ############################## @@ -153,61 +136,14 @@ def get_hf_weight_chunks(self, megatron_local_weights): # conversion_tasks = self._bridge.get_conversion_tasks(self.model) # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - # print(333333) - # print(self.model) - # print(333333) - # exit() - # conversion_tasks = self._bridge.get_conversion_tasks(self.model) - # conversion_tasks = self._bridge.build_adapter_conversion_tasks(self.model) - # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - # print(999999) - # print(conversion_tasks) - # print(999999) - # exit() - - - ### - # conversion_tasks = self._bridge.get_conversion_tasks(self.model) - # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - # lora_weights = self._bridge.export_hf_weights( - # self.model, - # cpu=False, - # conversion_tasks=conversion_tasks, - # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). - # ) - ### - - # self.model --> eval mode () ## - # problem in self._bridge.export_adapter_weights() # verl do the same thing - # print(self.model) --> self.model is a list - # for model_module in self.model: - # print(model_module, "training:", model_module.training) - # model_module.eval() - # print(model_module, "training:", model_module.training) - # print("0099") - # print(self.model) - - # print(self.model) - # print("---------") lora_weights = self._bridge.export_adapter_weights( self.model, - # cpu=False, - cpu=True, ### if False, it will have the problem - why? + cpu=False, + # cpu=True, ### if False, it will have the problem - why? 
# conversion_tasks=conversion_tasks, #### show_progress=False ) - # print(self.model) - # exit() - # for model_module in self.model: - # model_module.train() - - # for item in lora_weights: - # # print(i) - # # print(f"param_name: {item.param_name}, shape: {item[1].shape}, dtype: {item[1].dtype}") - # hf_param_name, weight, megatron_param_name = item - - # hf_param_name's might have big problem lora_weights = ( ( diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 218c43072..9ffe08e62 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -160,75 +160,6 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) - # ############################## - # ###########lora############### - # ############################## - # def _update_lora_via_tensor(self) -> None: - # """Push LoRA weights to rollout engines using tensors.""" - # lora_weights, config_dict = get_lora_weights_and_config(self.model) - # dist.barrier() - - # if dist.get_rank() == 0: - # serialized_tensors = MultiprocessingSerializer.serialize(lora_weights, output_str=True) - - # if self._lora_loaded: - # refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] - # ray.get(refs) - - # refs = [ - # engine.load_lora_adapter_from_tensors.remote(LORA_ADAPTER_NAME, serialized_tensors, config_dict) - # for engine in self.rollout_engines - # ] - # ray.get(refs) - - # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - # ray.get(refs) - - # self._lora_loaded = True - - # dist.barrier() - - - - # # _update_lora_via_file -- have not done/fix yet - # def _update_lora_via_file(self) -> None: - # """Push LoRA weights to rollout engines using disk files.""" - # self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) - # if dist.get_rank() == 0: - # if os.path.exists(self._lora_save_dir): - # delete_lora_from_disk(self._lora_save_dir) - - # dist.barrier() - - # save_lora_to_disk(self.model, self._lora_save_dir) - - # dist.barrier() - - # if dist.get_rank() == 0: - # if self._lora_loaded: - # refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] - # ray.get(refs) - - # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - # ray.get(refs) - - # refs = [ - # engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) - # for engine in self.rollout_engines - # ] - # ray.get(refs) - - # refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - # ray.get(refs) - - # self._lora_loaded = True - - # dist.barrier() - - # ############################## - # ############################## - # ############################## - def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: @@ -236,14 +167,13 @@ def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: ###########lora############### ############################## - # to-do (yusheng): need to deal with update_from_disk or tensor in this function + # to-do (yusheng): need to deal with lora update_from_tensor in sglang ############################## ############################## ############################## - all_refs = [] refs_colocated, long_lived_tensors = _send_to_colocated_engine( From 65948336953ebf99f6111f7894a2f2f15ec5d9d4 Mon Sep 17 00:00:00 2001 From: Yusheng 
Su Date: Sun, 18 Jan 2026 10:54:14 +0000 Subject: [PATCH 09/12] lora megatron backend - end2end training --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 13 +- miles/backends/megatron_utils/model.py | 66 +----- .../backends/megatron_utils/model_provider.py | 26 +-- .../hf_weight_iterator_bridge.py | 25 ++- .../update_weight_from_tensor.py | 208 +++++++++++++++--- miles/backends/sglang_utils/sglang_engine.py | 57 ++++- train.py | 5 +- 7 files changed, 265 insertions(+), 135 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 158e28cac..6fd84a731 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -1,5 +1,7 @@ #!/bin/bash +export FLASHINFER_DISABLE_VERSION_CHECK=1 + # for rerun the task pkill -9 sglang sleep 3 @@ -54,9 +56,9 @@ LORA_ARGS=( ############################## # # Debug #### inference - #--debug-rollout-only + # --debug-rollout-only ### --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/lora_adapter.pt - #--lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA + --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA ## --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ ### @@ -76,6 +78,10 @@ LORA_ARGS=( # Disable gradient accumulation fusion for LoRA training # --no-gradient-accumulation-fusion #Root cause: When training with LoRA, the base model’s parameters are frozen (requires_grad=False). However, Megatron-LM’s tensor-parallel layers use gradient-accumulation fusion during the backward pass, and that fusion path checks weight.main_grad.dtype. For frozen parameters, main_grad is never allocated (it remains None), which triggers the error. (enable) + + #### debug + --no-offload-train + # --no-offload-rollout ) ############################## ############################## @@ -151,7 +157,8 @@ OPTIMIZER_ARGS=( SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.7 + # --sglang-mem-fraction-static 0.7 + --sglang-mem-fraction-static 0.4 --sglang-enable-deterministic-inference --sglang-attention-backend flashinfer diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 2ec58456b..6578eba29 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -187,55 +187,13 @@ def setup_model_and_optimizer( ############################## ###########lora############### ############################## - # model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) - - # if is_lora_enabled(args): - - # from megatron.core.distributed import DistributedDataParallelConfig - # from megatron.bridge.models.model_provider import get_model - # provider = get_model_provider_func(args, role) - - # ddp_config = DistributedDataParallelConfig( - # grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), - # check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), - # overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), - # overlap_param_gather=getattr(args, 'overlap_param_gather', False), - # average_in_collective=getattr(args, 'average_in_collective', False), - # use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), - # ) - # # model = provider.provide_distributed_model( - # # ddp_config=ddp_config, - # # wrap_with_ddp=True, - # # bf16=getattr(args, 'bf16', False), - # # fp16=getattr(args, 'fp16', False), - # # ) - - # model = get_model( - # 
model_provider=provider, # must be ModelProviderMixin object - # ddp_config=ddp_config, - # model_type=ModelType.encoder_or_decoder, - # wrap_with_ddp=True, - # use_cpu_initialization=False, - # ) - - - # print(111111) - # print(model) - # print(111111) - # exit() - # else: - # model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) - - - ########### + # This part can be moved to `lora_utils.py` def apply_lora_to_megatron_model if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": ###### # refer to: verl/verl/workers/engine/megatron/transformer_impl.py ###### - - # if is_lora_enabled(args) and args.megatron_to_hf_mode == "bridge": # The below written as: get_model_provider_func() usage @@ -245,15 +203,7 @@ def setup_model_and_optimizer( from megatron.bridge.peft.lora import LoRA from megatron.bridge.peft.canonical_lora import CanonicalLoRA import torch - - # This is register_canonical_lora_adapter usgae - more advnace and efficiency!!!! - # Compare lora, canonical_lora_adapter, .... - # Build the provider from HF checkpoint - - # model: start - # args.hf_checkpoint - # model: done # bridge : start # hf config: @@ -265,6 +215,7 @@ def setup_model_and_optimizer( # bridge : done ##!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # provider: start + # I can also use bridge_get_model() method provider = bridge.to_megatron_provider(load_weights=False) # should be True??? # provider = bridge.to_megatron_provider(load_weights=True) # different from full model training - in the training script, I need to load tuned base model weight and initial lora weights. Need to carefully check and optimize - where to load the base model? (but why in `model_provider.py` using: provider = bridge.to_megatron_provider(load_weights=False)) @@ -316,7 +267,7 @@ def setup_model_and_optimizer( # print("========") # Create LoRA config - # (to-do) yusheng set - lora_type - the default is LoRA + # (to-do) yusheng set - lora_type - the default is LoRA: should support all lora lora = LoRA( # lora = CanonicalLoRA( target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=LoRA), @@ -371,10 +322,8 @@ def apply_lora_hook(model_chunks): # Register the hook provider.register_pre_wrap_hook(apply_lora_hook) # provider.finalize() - - - + #### ------------------- ##### #### ------------------- ##### #### ------------------- ##### @@ -408,11 +357,7 @@ def apply_lora_hook(model_chunks): # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks for callback in post_model_creation_callbacks: provider.register_pre_wrap_hook(callback) - - # print("====") - # print(wrap_config.wrap_with_ddp) - # print("====") - # exit() + if wrap_config.wrap_with_ddp: from megatron.bridge.training.config import DistributedDataParallelConfig @@ -469,6 +414,7 @@ def apply_lora_hook(model_chunks): # if hasattr(args, '_lora_instance'): # args._lora_instance = lora ############# + else: # Original non-LoRA path or non-bridge mode diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index cac560a6c..19aabfd42 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -80,15 +80,9 @@ def wrapped_model_provider( return wrapped_model_provider + if args.megatron_to_hf_mode == "bridge": from megatron.bridge import AutoBridge - ############################## - ###########lora############### - ############################## - # from 
miles.backends.megatron_utils.lora_utils import is_lora_enabled - ############################## - ############################## - ############################## bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) provider = bridge.to_megatron_provider(load_weights=False) @@ -100,24 +94,6 @@ def wrapped_model_provider( provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel - # ############################## - # ###########lora############### - # ############################## - # if is_lora_enabled(args) and role == "actor": - # provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size - # provider.context_parallel_size = args.context_parallel_size - # provider.variable_seq_lengths = True - # provider.moe_token_dispatcher_type = "alltoall" - # provider.moe_router_load_balancing_type = "none" - # provider.finalize() - # return provider.provide - # else: - # provider.finalize() - # return provider.provide - # ############################## - # ############################## - # ############################## - provider.finalize() return provider.provide diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 9e316ec4e..d82740381 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -16,6 +16,22 @@ def _normalize_base_weight_name(param_name: str) -> str: if param_name.endswith("base_layer.weight"): return param_name[: -len("base_layer.weight")] + "weight" return param_name + +# CanonicalLoRA - same as sglang +# Lora - need to below convert (or sglang also need to use Lora) +# def _convert_lora_name_for_sglang(hf_param_name: str) -> str: +# """Convert standard HF LoRA names to SGLang's merged projection format.""" +# # Handle attention projections - SGLang uses merged qkv_proj +# for proj in ['q_proj', 'k_proj', 'v_proj']: +# if f'.self_attn.{proj}.' in hf_param_name: +# return hf_param_name.replace(f'.self_attn.{proj}.', '.self_attn.qkv_proj.') + +# # Handle MLP projections - SGLang uses merged gate_up_proj +# for proj in ['gate_proj', 'up_proj']: +# if f'.mlp.{proj}.' in hf_param_name: +# return hf_param_name.replace(f'.mlp.{proj}.', '.mlp.gate_up_proj.') + +# return hf_param_name ############################## ############################## ############################## @@ -44,6 +60,7 @@ def __init__(self, *args, **kwargs): def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) + renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} @@ -130,25 +147,27 @@ def get_hf_weight_chunks(self, megatron_local_weights): ############################## ###########lora############### ############################## - # print(4444444) if self.is_lora: # (to-do) yusheng: I might need to add the converting weights (mg --> hf) - refer above # conversion_tasks = self._bridge.get_conversion_tasks(self.model) # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + # print(self.model) # Identity() + lora_weights = self._bridge.export_adapter_weights( self.model, cpu=False, - # cpu=True, ### if False, it will have the problem - why? 
+ # cpu=True, ### if False, it will have cudaaccess error # conversion_tasks=conversion_tasks, #### show_progress=False ) + # hf_param_name's might have big problem lora_weights = ( ( hf_param_name, - postprocess_hf_param( + postprocess_hf_param( # check if need postprocess_hf_param args=self.args, megatron_param_name=megatron_param_name, hf_param_name=hf_param_name, diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 9ffe08e62..5cf724e3e 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -23,6 +23,10 @@ ###########lora############### ############################## from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled + +def _is_lora_weight(name: str) -> bool: + """Check if a weight name is a LoRA adapter weight.""" + return ".lora_A." in name or ".lora_B." in name ############################## ############################## ############################## @@ -79,6 +83,22 @@ def __init__( is_lora=self.is_lora, _base_synced=self._base_synced, ) + + + # Store LoRA config for weight sync (from tensor) - why this - let me think think + if self.is_lora: + from miles.backends.sglang_utils.sglang_engine import convert_target_modules_to_hf + self._lora_config = { + "peft_type": "LORA", + "r": args.lora_rank, + "lora_alpha": args.lora_alpha, + "target_modules": convert_target_modules_to_hf(list(args.target_modules)) if args.target_modules else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "lora_dropout": args.lora_dropout, + "bias": "none", + "task_type": "CAUSAL_LM", + } + else: + self._lora_config = None ############################## ############################## ############################## @@ -150,53 +170,98 @@ def update_weights(self) -> None: megatron_local_weights = self.weights_getter() - # error in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights) + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - import logging - logger = logging.getLogger(__name__) refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) ray.get(refs) del long_lived_tensors dist.barrier(group=get_gloo_group()) - - def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: + ############################## + ###########lora############### + ############################## + # def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: - ############################## - ###########lora############### - ############################## - - # to-do (yusheng): need to deal with lora update_from_tensor in sglang + # all_refs = [] + + # refs_colocated, long_lived_tensors = _send_to_colocated_engine( + # hf_named_tensors, + # ipc_engine=self._ipc_engine, + # ipc_gather_src=self._ipc_gather_src, + # ipc_gather_group=self._ipc_gather_group, + # weight_version=self.weight_version, + # ) + # all_refs.extend(refs_colocated) + + # if self.use_distribute and self._is_distributed_src_rank: + # refs_distributed = update_weights_from_distributed( + # self._group_name, + # self._model_update_groups, + # self.weight_version, + # self.distributed_rollout_engines, + # hf_named_tensors, + # ) + # if refs_distributed: + # all_refs.extend(refs_distributed) + + # return all_refs, long_lived_tensors + + ############## + + def 
_send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: - ############################## - ############################## - ############################## + all_refs = [] + # Separate LoRA weights from base weights + if self.is_lora: + base_tensors = [(n, t) for n, t in hf_named_tensors if not _is_lora_weight(n)] + lora_tensors = [(n, t) for n, t in hf_named_tensors if _is_lora_weight(n)] + else: + base_tensors = hf_named_tensors + lora_tensors = [] + + # Send base model weights via update_weights_from_tensor + long_lived_tensors = [] + if base_tensors: + refs_colocated, long_lived_tensors = _send_to_colocated_engine( + base_tensors, + ipc_engine=self._ipc_engine, + ipc_gather_src=self._ipc_gather_src, + ipc_gather_group=self._ipc_gather_group, + weight_version=self.weight_version, + ) + all_refs.extend(refs_colocated) + + if self.use_distribute and self._is_distributed_src_rank: + refs_distributed = update_weights_from_distributed( + self._group_name, + self._model_update_groups, + self.weight_version, + self.distributed_rollout_engines, + base_tensors, + ) + if refs_distributed: + all_refs.extend(refs_distributed) - all_refs = [] - - refs_colocated, long_lived_tensors = _send_to_colocated_engine( - hf_named_tensors, - ipc_engine=self._ipc_engine, - ipc_gather_src=self._ipc_gather_src, - ipc_gather_group=self._ipc_gather_group, - weight_version=self.weight_version, - ) - all_refs.extend(refs_colocated) - - if self.use_distribute and self._is_distributed_src_rank: - refs_distributed = update_weights_from_distributed( - self._group_name, - self._model_update_groups, - self.weight_version, - self.distributed_rollout_engines, - hf_named_tensors, + # Send LoRA weights via load_lora_adapter_from_tensors + if lora_tensors and self._lora_config is not None: + refs_lora, lora_long_lived = _send_lora_to_colocated_engine( + lora_tensors, + ipc_engine=self._ipc_engine, + ipc_gather_src=self._ipc_gather_src, + ipc_gather_group=self._ipc_gather_group, + lora_config=self._lora_config, + lora_name=LORA_ADAPTER_NAME, ) - if refs_distributed: - all_refs.extend(refs_distributed) + all_refs.extend(refs_lora) + long_lived_tensors.extend(lora_long_lived) return all_refs, long_lived_tensors + + ############################## + ############################## + ############################## def _send_to_colocated_engine( @@ -219,7 +284,7 @@ def _send_to_colocated_engine( if dtype not in converted_named_tensors_by_dtypes: converted_named_tensors_by_dtypes[dtype] = [] converted_named_tensors_by_dtypes[dtype].append((name, tensor)) - + serialized_tensors = [] for _dtype, named_tensors in converted_named_tensors_by_dtypes.items(): flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=named_tensors) @@ -254,3 +319,78 @@ def _send_to_colocated_engine( refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs)) return refs, long_live_tensors + + + + + +############################## +###########lora############### +############################## +def _send_lora_to_colocated_engine( + lora_named_tensors: list[tuple[str, torch.Tensor]], + *, + ipc_engine, + ipc_gather_src, + ipc_gather_group, + lora_config: dict, + lora_name: str, +) -> tuple[list[ObjectRef], Any]: + """Send LoRA weights to colocated engine via load_lora_adapter_from_tensors. + + Uses FlattenedTensorBucket for cross-process serialization (same as base weights). 
+ """ + + long_live_tensors = [] + + # Use FlattenedTensorBucket (same as base weights) for proper cross-process serialization + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + long_live_tensors.append(flattened_tensor_data) + serialized_lora = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + + # Gather from all ranks in the group + serialized_lora_gathered = ( + [None] * dist.get_world_size(ipc_gather_group) if ipc_gather_src == dist.get_rank() else None + ) + dist.gather_object( + serialized_lora, + object_gather_list=serialized_lora_gathered, + dst=ipc_gather_src, + group=ipc_gather_group, + ) + + # refs = [] + # if dist.get_rank() == ipc_gather_src: + # # Send LoRA via the same mechanism as base weights, but use a special endpoint + # refs.append(ipc_engine.load_lora_adapter_from_tensors.remote( + # lora_name=lora_name, + # serialized_tensors=serialized_lora_gathered[0], # FlattenedTensorBucket format + # config_dict=lora_config, + # load_format="flattened_bucket", # Add this to indicate the format + # )) + + refs = [] + if dist.get_rank() == ipc_gather_src: + # First, unload the existing LoRA adapter (if any) + try: + ray.get(ipc_engine.unload_lora_adapter.remote(lora_name=lora_name)) + except Exception: + pass # Ignore error if adapter was not loaded + + # Then load the new LoRA weights + refs.append(ipc_engine.load_lora_adapter_from_tensors.remote( + lora_name=lora_name, + serialized_tensors=serialized_lora_gathered[0], # FlattenedTensorBucket format + config_dict=lora_config, + load_format="flattened_bucket", # Add this to indicate the format + )) + + return refs, long_live_tensors +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 3ebfd8558..814dba23d 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -56,6 +56,7 @@ def convert_target_modules_to_hf(megatron_modules: list[str]) -> list[str]: hf_modules.append(module) return hf_modules + ############################## ############################## ############################## @@ -330,6 +331,40 @@ def update_weights_from_tensor( "update_weights_from_tensor", payload, ) + + ############################## + ###########lora############### + ############################## + def load_lora_adapter_from_tensors( + self, + lora_name: str, + serialized_tensors: str, + config_dict: dict, + load_format: str | None = None, # Add this parameter + pinned: bool = False, + added_tokens_config: dict | None = None, + ): + """ + Load a LoRA adapter from serialized tensor data. 
+ """ + payload = { + "lora_name": lora_name, + "serialized_tensors": serialized_tensors, + "config_dict": config_dict, + "pinned": pinned, + } + if load_format is not None: + payload["load_format"] = load_format + if added_tokens_config is not None: + payload["added_tokens_config"] = added_tokens_config + + return self._make_request( + "load_lora_adapter_from_tensors", + payload, + ) + ############################## + ############################## + ############################## def flush_cache(self): """Flush the cache of the server.""" @@ -409,12 +444,12 @@ def get_weight_version(self): # {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, # ) - # def unload_lora_adapter(self, lora_name: str): - # """Unload LoRA adapter.""" - # return self._make_request( - # "unload_lora_adapter", - # {"lora_name": lora_name}, - # ) + def unload_lora_adapter(self, lora_name: str): + """Unload LoRA adapter.""" + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) ############################## ############################## ############################## @@ -635,10 +670,18 @@ def _compute_server_args( kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) ##### For rollout debug mode to add: - if args.debug_rollout_only and args.lora_adapter_path: + if args.debug_rollout_only: + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} + kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} + + if args.lora_adapter_path is None: + raise ValueError("lora_adapter_path must be provided") + else: from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} + ############################## ############################## ############################## diff --git a/train.py b/train.py index a86106e0d..c41687ba8 100644 --- a/train.py +++ b/train.py @@ -53,15 +53,14 @@ def train(args): # ###########lora############### # ############################## + if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # always update weight first so that sglang has the loaded weights from training. 
- print(11111) actor_model.update_weights() - print(22222) - exit() + if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) From ee92631f2df8c67f341023cc245ab23ec7e5b0b3 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Wed, 21 Jan 2026 08:41:08 +0000 Subject: [PATCH 10/12] enable no --lora-adapter-path --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 36 ++++++++++++------- miles/backends/sglang_utils/sglang_engine.py | 11 +++--- miles/utils/arguments.py | 2 +- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 6fd84a731..a19e49819 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -25,8 +25,8 @@ CKPT_ARGS=( # --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ # Uncomment to save checkpoints (required for LoRA) #### train - --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ - --save-interval 5 + # --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + # --save-interval 100 ### ) @@ -39,7 +39,7 @@ CKPT_ARGS=( ###########lora############### ############################## LORA_ARGS=( - --lora-rank 16 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) --lora-alpha 32 # LoRA alpha (usually 2x rank) --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) # Target modules - use Megatron naming or HF naming @@ -58,7 +58,7 @@ LORA_ARGS=( #### inference # --debug-rollout-only ### --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/lora_adapter.pt - --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA + # --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA ## --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ ### @@ -94,20 +94,21 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type math - # --num-rollout 100 - --num-rollout 10 # onyl train 10 stesp - # --rollout-batch-size 32 - --rollout-batch-size 16 # for testing + --num-rollout 100 + # --num-rollout 10 # onyl train 10 stesp + --rollout-batch-size 32 + # --rollout-batch-size 16 # for testing --n-samples-per-prompt 8 --rollout-max-response-len 1024 --rollout-temperature 1 - # --global-batch-size 256 - --global-batch-size 32 # for testing + --global-batch-size 256 + # --global-batch-size 32 # for testing ) EVAL_ARGS=( - --eval-interval 20 + # --eval-interval 20 + --eval-interval 10 --eval-prompt-data gsm8k /root/gsm8k/test.parquet --n-samples-per-eval-prompt 1 --eval-max-response-len 1024 @@ -155,6 +156,15 @@ OPTIMIZER_ARGS=( # --wandb-group qwen2.5-0.5B-gsm8k-deterministic # ) +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen2.5-0.5B-gsm8k-test +) + + SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 # --sglang-mem-fraction-static 0.7 @@ -182,9 +192,9 @@ MISC_ARGS=( ###########lora############### ############################## ######## Note: Need to set export CUDA_VISIBLE_DEVICES= , or it will fail and have cuda error -export GPUS_PER_NODE=1 +# export GPUS_PER_NODE=1 # export GPUS_PER_NODE=2 -# export GPUS_PER_NODE=4 +export GPUS_PER_NODE=4 # export GPUS_PER_NODE=8 ############################## ############################## diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 814dba23d..c3abdd325 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ 
b/miles/backends/sglang_utils/sglang_engine.py @@ -653,10 +653,10 @@ def _compute_server_args( # # print(2222222222) # # exit() - # if args.lora_rank > 0 or args.lora_adapter_path is not None: - if is_lora_enabled(args): - kwargs["max_loras_per_batch"] = 1 #!!!!!!!! + # if is_lora_enabled(args): + if args.lora_rank > 0 or args.lora_adapter_path is not None: kwargs["enable_lora"] = True + kwargs["max_loras_per_batch"] = 1 #!!!!!!!! # kwargs["max_lora_rank"] = args.lora_rank # Ensure a valid positive LoRA rank is passed to the SGLang engine. # If LoRA is enabled via adapter path but lora_rank is not set to a @@ -676,7 +676,10 @@ def _compute_server_args( kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} if args.lora_adapter_path is None: - raise ValueError("lora_adapter_path must be provided") + # raise ValueError("lora_adapter_path must be provided") + # raise ValueError("lora_adapter_path must be provided") + # pass + logger.info("Did not provide pre-trained LoRA adapter_path, will use random initial weights") else: from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 76a042cab..2b1822304 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1613,7 +1613,7 @@ def miles_validate_args(args): ###########lora############### ############################## if args.lora_rank > 0: - assert args.save is not None, "'--save' is required when LoRA is enabled." + # assert args.save is not None, "'--save' is required when LoRA is enabled." assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." # (to-do) yusheng: hf->mg; mg->hf From 4953c113420288d1014a9c9ccf391d21eb8d516a Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Thu, 22 Jan 2026 06:01:01 +0000 Subject: [PATCH 11/12] update script --- .../run-qwen2.5-0.5B-gsm8k-lora.sh | 4 + examples/reproducibility/run-qwen3-4B-lora.sh | 206 ++++++++++++++++++ miles/backends/megatron_utils/model.py | 4 +- 3 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 examples/reproducibility/run-qwen3-4B-lora.sh diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index a19e49819..4e54e2c3f 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -1,5 +1,8 @@ #!/bin/bash +# Debug this: +# "--offload-rollout-level kv_cache weight " + export FLASHINFER_DISABLE_VERSION_CHECK=1 # for rerun the task @@ -156,6 +159,7 @@ OPTIMIZER_ARGS=( # --wandb-group qwen2.5-0.5B-gsm8k-deterministic # ) + WANDB_ARGS=( --use-wandb --wandb-host https://wandb.ai/ diff --git a/examples/reproducibility/run-qwen3-4B-lora.sh b/examples/reproducibility/run-qwen3-4B-lora.sh new file mode 100644 index 000000000..c167424ae --- /dev/null +++ b/examples/reproducibility/run-qwen3-4B-lora.sh @@ -0,0 +1,206 @@ +#!/bin/bash + +# Example launcher that reuses the Qwen3-4B recipe but delegates evaluation to an +# external Nemo Skills server via the eval_delegate_rollout wrapper. + +# Clean up any stale processes from a previous run. 
+pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +SKILLS_OPENAI_MODEL_NAME=${SKILLS_OPENAI_MODEL_NAME:-"miles-openai-model"} + + +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +source "${REPO_ROOT}/miles/scripts/models/qwen3-4B.sh" + +# Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. +# EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/multi_tasks.yaml"} +EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/miles/examples/eval/scripts/multi_tasks.yaml"} + +CKPT_ARGS=( + # --hf-checkpoint /root/Qwen3-4B + --hf-checkpoint /root/models/Qwen3-4B + # --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_miles/ + # --save /root/Qwen3-4B_miles/ + # --save-interval 20 +) + + +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + --target-modules "all-linear" + --megatron-to-hf-mode bridge + #### debug + --no-offload-train + # --no-offload-rollout +) + + + +ROLLOUT_ARGS=( + # --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + # --rollout-batch-size 32 + --rollout-batch-size 16 + --n-samples-per-prompt 8 + # --rollout-max-response-len 8192 + --rollout-max-response-len 2048 + --rollout-temperature 1 + --over-sampling-batch-size 64 + + --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + # --global-batch-size 256 + --global-batch-size 128 + --balance-data +) + +# EVAL_ARGS=( +# --eval-interval 5 +# --eval-config "${EVAL_CONFIG_PATH}" +# --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +# ) + +EVAL_ARGS=( + --eval-interval 5 + --eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 2 + --eval-max-response-len 16384 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen3-4B-test +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + # --sglang-mem-fraction-static 0.7 + --sglang-mem-fraction-static 0.4 + + --sglang-enable-deterministic-inference + --sglang-attention-backend 
flashinfer + + --deterministic-mode +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +# export CUDA_VISIBLE_DEVICES=0,1 +# Set Up Your GPUs for Training + +# export GPUS_PER_NODE=2 #default +export GPUS_PER_NODE=4 + +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus $GPUS_PER_NODE --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + + +# ray job submit --address="http://127.0.0.1:8265" \ + # --runtime-env-json="${RUNTIME_ENV_JSON}" \ + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 6578eba29..910d6a617 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -279,8 +279,8 @@ def setup_model_and_optimizer( dropout=args.lora_dropout, ##Below for Lora # dropout_position=getattr(args, 'lora_dropout_position', 'pre'), - # lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), - # lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), + lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), + lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # a2a_experimental=getattr(args, 'lora_a2a_experimental', False), ) From 8da0dfc677cb96dc7c466fa796fe44f2314440bf Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Thu, 22 Jan 2026 07:18:35 +0000 Subject: [PATCH 12/12] to-do: need to enable --no-offload-train and --no-offload-rollout --- examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh | 2 +- examples/reproducibility/run-qwen3-4B-lora.sh | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh index 4e54e2c3f..2df0390f9 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -84,7 +84,7 @@ LORA_ARGS=( #### debug --no-offload-train - # --no-offload-rollout + --no-offload-rollout ) ############################## ############################## diff --git a/examples/reproducibility/run-qwen3-4B-lora.sh b/examples/reproducibility/run-qwen3-4B-lora.sh index c167424ae..c099ef603 100644 --- a/examples/reproducibility/run-qwen3-4B-lora.sh +++ b/examples/reproducibility/run-qwen3-4B-lora.sh @@ -36,6 +36,7 @@ source "${REPO_ROOT}/miles/scripts/models/qwen3-4B.sh" # EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/multi_tasks.yaml"} EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/miles/examples/eval/scripts/multi_tasks.yaml"} + CKPT_ARGS=( # --hf-checkpoint /root/Qwen3-4B --hf-checkpoint 
/root/models/Qwen3-4B @@ -54,7 +55,7 @@ LORA_ARGS=( --megatron-to-hf-mode bridge #### debug --no-offload-train - # --no-offload-rollout + --no-offload-rollout )
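
Editor's note on the sglang_engine.py hunks in PATCH 10: the net effect of the LoRA gating can be summarised with the sketch below. The kwargs keys (enable_lora, max_loras_per_batch, lora_paths) and the args fields mirror the diff above; the standalone function name apply_lora_server_args, the DEFAULT_LORA_RANK fallback, and the LORA_ADAPTER_NAME value are illustrative assumptions only, not the actual _compute_server_args implementation (where the adapter-name constant is imported from miles.backends.megatron_utils.lora_utils).

    # Minimal sketch of the LoRA branch left in _compute_server_args by PATCH 10.
    # Function name, DEFAULT_LORA_RANK, and the demo namespace are hypothetical.
    import logging
    from types import SimpleNamespace

    logger = logging.getLogger(__name__)

    LORA_ADAPTER_NAME = "lora_adapter"  # placeholder; the real constant lives in lora_utils
    DEFAULT_LORA_RANK = 16              # hypothetical fallback when only an adapter path is given


    def apply_lora_server_args(args, kwargs):
        """Populate SGLang server kwargs when LoRA is requested via rank or adapter path."""
        if args.lora_rank > 0 or args.lora_adapter_path is not None:
            kwargs["enable_lora"] = True
            kwargs["max_loras_per_batch"] = 1  # one adapter slot is enough for RL rollouts
            # The engine still needs a positive rank even when LoRA is enabled
            # only through an adapter path.
            kwargs["max_lora_rank"] = args.lora_rank if args.lora_rank > 0 else DEFAULT_LORA_RANK

            if args.lora_adapter_path is None:
                # PATCH 10 downgrades the old hard error to a log message:
                # training starts from randomly initialised LoRA weights.
                logger.info("No pre-trained LoRA adapter path provided; "
                            "using randomly initialised LoRA weights")
            else:
                # SGLang expects lora_paths as {"adapter_name": "path_to_adapter"}.
                kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path}
        return kwargs


    if __name__ == "__main__":
        demo_args = SimpleNamespace(lora_rank=32, lora_adapter_path=None)
        print(apply_lora_server_args(demo_args, {}))

This matches the relaxed validation in miles/utils/arguments.py from the same patch, where only --target-modules remains mandatory once --lora-rank is positive.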
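
Editor's note on the model.py change in PATCH 11 (lora_A_init_method='xavier', lora_B_init_method='zero'): the usual rationale, sketched here with toy NumPy code rather than the Megatron implementation, is that a zero-initialised B makes the adapter contribution (alpha/rank) * B @ A exactly zero at step 0, so rollouts before the first optimizer step reproduce the frozen base model. Shapes and values below are arbitrary; with the scripts' --lora-rank 32 and --lora-alpha 32, the scaling factor alpha/rank is 1.

    # Toy illustration of why lora_B_init_method="zero" keeps the adapted layer
    # identical to the base layer at initialisation. Shapes are arbitrary.
    import numpy as np

    d_out, d_in, rank, alpha = 64, 64, 32, 32
    rng = np.random.default_rng(0)

    W = rng.normal(size=(d_out, d_in))                                   # frozen base weight
    A = rng.normal(scale=np.sqrt(2.0 / (d_in + rank)), size=(rank, d_in))  # Xavier/Glorot-style init
    B = np.zeros((d_out, rank))                                          # zero init

    delta_W = (alpha / rank) * (B @ A)                                   # LoRA update, zero at step 0
    assert np.allclose(W + delta_W, W)                                   # adapted weight == base weight
    print("max |delta_W| at init:", np.abs(delta_W).max())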