diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh new file mode 100644 index 000000000..2df0390f9 --- /dev/null +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh @@ -0,0 +1,242 @@ +#!/bin/bash + +# Debug this: +# "--offload-rollout-level kv_cache weight " + +export FLASHINFER_DISABLE_VERSION_CHECK=1 + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-0.5B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ + # --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ + # Uncomment to save checkpoints (required for LoRA) + #### train + # --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + # --save-interval 100 + ### +) + +# target-module only support: linear (the proj_ need to be supported in future: in Megatron-Bridge/src/megatron/bridge/models/conversion/peft_bridge.py - build_adapter_conversion_tasks) +# example: if one module have two lora: (linear_proj): LoRALinear(), (linear_fc2): LoRALinear() + +# LORA_TARGET_MODULES=${LORA_TARGET_MODULES:-"['linear_qkv','linear_proj','linear_fc1','linear_fc2']"}. It will broken + +############################## +###########lora############### +############################## +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + # Target modules - use Megatron naming or HF naming + # Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + # Need this PR: Update LoRA Weights via Tensor sgl-project/sglang#16226 + # --lora-sync-from-tensor # Use tensor-based sync (more efficient) + # # Uncomment to share base model between actor and ref (saves memory) + # --share-ref-base-model + + --target-modules "all-linear" + # --target-modules "o_proj,down_proj,k_proj,gate_proj,q_proj,v_proj,up_proj" + # --target-modules "q_proj,k_proj,v_proj,o_proj" + ############################## + ############################## + # # Debug + #### inference + # --debug-rollout-only + ### --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/lora_adapter.pt + # --lora-adapter-path lewtun/Qwen2.5-0.5B-SFT-LoRA + ## --lora-adapter-path /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + ### + + #### train + # --debug-train-only + # --load-debug-rollout-data /root/debug_data/rollout_data.pt + ## --save /root/checkpoints/qwen2.5-0.5B-lora-megatron/ + + # --save-debug-rollout-data /root/debug_data/rollout_data.pt + ### + ############################## + ############################## + # --no-use-distributed-optimizer # if open it will has error: /home/radixark/yushengsu/miles-pr/miles/miles/utils/arguments.py: + #def set_default_megatron_args(args): (error) # optimizer cannot distributed to other gpus (enable) + + --megatron-to-hf-mode bridge + # Disable gradient accumulation fusion for LoRA training + + # --no-gradient-accumulation-fusion #Root cause: When training with LoRA, the base model’s parameters are frozen (requires_grad=False). However, Megatron-LM’s tensor-parallel layers use gradient-accumulation fusion during the backward pass, and that fusion path checks weight.main_grad.dtype. 
For frozen parameters, main_grad is never allocated (it remains None), which triggers the error. (enable) + + #### debug + --no-offload-train + --no-offload-rollout +) +############################## +############################## +############################## + +ROLLOUT_ARGS=( + --prompt-data /root/gsm8k/train.parquet + --input-key messages + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 100 + # --num-rollout 10 # onyl train 10 stesp + --rollout-batch-size 32 + # --rollout-batch-size 16 # for testing + --n-samples-per-prompt 8 + --rollout-max-response-len 1024 + --rollout-temperature 1 + + --global-batch-size 256 + # --global-batch-size 32 # for testing +) + +EVAL_ARGS=( + # --eval-interval 20 + --eval-interval 10 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # --use-kl-loss # if use kl loss, should use --ref-load + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 # Higher LR often works better for LoRA + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +# WANDB_ARGS=( +# --use-wandb +# --wandb-host https://wandb.ai/ +# --wandb-team glm-zero +# --wandb-project miles-dev +# --wandb-group qwen2.5-0.5B-gsm8k-deterministic +# ) + + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen2.5-0.5B-gsm8k-test +) + + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + # --sglang-mem-fraction-static 0.7 + --sglang-mem-fraction-static 0.4 + + --sglang-enable-deterministic-inference + --sglang-attention-backend flashinfer + + --deterministic-mode +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + + +############################## +###########lora############### +############################## +######## Note: Need to set export CUDA_VISIBLE_DEVICES= , or it will fail and have cuda error +# export GPUS_PER_NODE=1 +# export GPUS_PER_NODE=2 +export GPUS_PER_NODE=4 +# export GPUS_PER_NODE=8 +############################## +############################## +############################## + +# launch the master node of ray in container +ray start --head --node-ip-address 127.0.0.1 --num-gpus $GPUS_PER_NODE --disable-usage-stats +# ray start --head --node-ip-address 127.0.0.1 --num-gpus 1 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + --calculate-per-token-loss \ + --use-miles-router \ + 
${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${LORA_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${LORA_ARGS[@]} + + +# colocate : update from tesnor +# disaggrate : update from distributed \ No newline at end of file diff --git a/examples/reproducibility/run-qwen3-4B-lora.sh b/examples/reproducibility/run-qwen3-4B-lora.sh new file mode 100644 index 000000000..c099ef603 --- /dev/null +++ b/examples/reproducibility/run-qwen3-4B-lora.sh @@ -0,0 +1,207 @@ +#!/bin/bash + +# Example launcher that reuses the Qwen3-4B recipe but delegates evaluation to an +# external Nemo Skills server via the eval_delegate_rollout wrapper. + +# Clean up any stale processes from a previous run. +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +SKILLS_OPENAI_MODEL_NAME=${SKILLS_OPENAI_MODEL_NAME:-"miles-openai-model"} + + +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +source "${REPO_ROOT}/miles/scripts/models/qwen3-4B.sh" + +# Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. +# EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/multi_tasks.yaml"} +EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/miles/examples/eval/scripts/multi_tasks.yaml"} + + +CKPT_ARGS=( + # --hf-checkpoint /root/Qwen3-4B + --hf-checkpoint /root/models/Qwen3-4B + # --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_miles/ + # --save /root/Qwen3-4B_miles/ + # --save-interval 20 +) + + +LORA_ARGS=( + --lora-rank 32 # LoRA rank (typical values: 8, 16, 32, 64) + --lora-alpha 32 # LoRA alpha (usually 2x rank) + --lora-dropout 0.0 # LoRA dropout (0.0 for RL training) + --target-modules "all-linear" + --megatron-to-hf-mode bridge + #### debug + --no-offload-train + --no-offload-rollout +) + + + +ROLLOUT_ARGS=( + # --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + # --rollout-batch-size 32 + --rollout-batch-size 16 + --n-samples-per-prompt 8 + # --rollout-max-response-len 8192 + --rollout-max-response-len 2048 + --rollout-temperature 1 + --over-sampling-batch-size 64 + + --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + # --global-batch-size 256 + --global-batch-size 128 + --balance-data +) + +# EVAL_ARGS=( +# --eval-interval 5 +# --eval-config "${EVAL_CONFIG_PATH}" +# --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +# ) + +EVAL_ARGS=( + --eval-interval 5 + --eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 2 + --eval-max-response-len 16384 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + 
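+    # Activation recomputation (next three flags): re-run each layer's forward pass
+    # during backward instead of storing activations, trading extra compute for memory.
+    # LoRA keeps optimizer state small, but activations still grow with response length.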
--recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + # --lr 1e-6 + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-team miles-lora + --wandb-project miles-lora-megatron + --wandb-group qwen3-4B-test +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + # --sglang-mem-fraction-static 0.7 + --sglang-mem-fraction-static 0.4 + + --sglang-enable-deterministic-inference + --sglang-attention-backend flashinfer + + --deterministic-mode +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +# export CUDA_VISIBLE_DEVICES=0,1 +# Set Up Your GPUs for Training + +# export GPUS_PER_NODE=2 #default +export GPUS_PER_NODE=4 + +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus $GPUS_PER_NODE --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + + +# ray job submit --address="http://127.0.0.1:8265" \ + # --runtime-env-json="${RUNTIME_ENV_JSON}" \ + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node $GPUS_PER_NODE \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 7cc7f2619..4e00a78f9 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -37,7 +37,19 @@ from .update_weight.common import named_params_and_buffers from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed from .update_weight.update_weight_from_tensor import UpdateWeightFromTensor - +############################## +###########lora############### +############################## +from .lora_utils import ( + is_lora_enabled, + is_lora_model, + # apply_lora_to_megatron_model, + # get_lora_weights_and_config, + freeze_base_model, +) +############################## +############################## +############################## logging.getLogger("megatron").setLevel(logging.WARNING) logger = logging.getLogger(__name__) @@ -74,8 +86,10 @@ def init( } dist.barrier(group=get_gloo_group()) + if args.offload_train: if (x := args.train_memory_margin_bytes) > 0: + # --train-memory-margin-bytes can tune this logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") torch_memory_saver.memory_margin_bytes = x @@ -91,7 +105,7 @@ def init( (self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer( 
args, role ) - + if role == "critic": if self.args.offload_train: self.sleep() @@ -110,10 +124,24 @@ def init( ) self._active_model_tag: str | None = "actor" self.weights_backuper.backup("actor") + #actor already include lora now - You can use self.weights_backuper.get("actor") to check + + # ############### + # print("=======") + # actor_weights = self.weights_backuper.get("actor") + + # for name, weight in actor_weights.items(): + # print(f"{name}: shape={weight.shape}, sum={weight.float().sum().item()}, requires_grad: {weight.requires_grad}") + # actor_weights = self.weights_backuper.get("actor") + # print("=======") + # exit() + # and then it will update_weight (sync) to sglang - so it does not need requires_grad. + # ############### + if with_ref: self.load_other_checkpoint("ref", args.ref_load) - + if self.args.keep_old_actor: # Load old_actor checkpoint self.load_other_checkpoint("old_actor", args.load) @@ -131,6 +159,13 @@ def init( weights_getter=lambda: self.weights_backuper.get("actor"), model_name=type(self.hf_config).__name__.lower() if self.args.model_name is None else self.args.model_name, quantization_config=getattr(self.hf_config, "quantization_config", None), + ############################## + ###########lora############### + ############################## + is_lora=is_lora_enabled(args), + ############################## + ############################## + ############################## ) # empty cache after initialization @@ -248,6 +283,18 @@ def _switch_model(self, target_tag: str) -> None: self.weights_backuper.restore(target_tag) self._active_model_tag = target_tag + ############################## + ###########lora############### + ############################## + # # Restore requires_grad after weight restoration + # # For LoRA training: only adapter params should be trainable, base model frozen + # if is_lora_enabled(self.args): + # freeze_base_model(self.model) + # # Note: ref model uses forward_only (@torch.no_grad), so requires_grad doesn't matter + ############################## + ############################## + ############################## + def fill_routing_replay(self, data_iterator, num_microbatches, rollout_data): if "rollout_routed_experts" not in rollout_data: raise ValueError( @@ -545,7 +592,7 @@ def update_weights(self) -> None: if self.args.offload_train: destroy_process_groups() - + def load_other_checkpoint(self, model_tag: str, path: str) -> None: old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune self.args.load = path @@ -594,3 +641,75 @@ def connect_actor_critic( rank=0 if self.role == "actor" else 1, group_name=group_name, ) + + + + + ############################## + ###########lora############### + ############################## + ########### + def check_lora_status(self): + """check LoRA module""" + from megatron.bridge.peft.lora_layers import LoRALinear + from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + + results = { + "lora_modules": [], + "trainable_params": 0, + "frozen_params": 0, + "total_params": 0, + } + + model = self.model[0] if isinstance(self.model, list) else self.model + + # travel all module + for name, module in model.named_modules(): + if isinstance(module, (LoRALinear, AdapterWrapper)): + results["lora_modules"].append(name) + + # check adapter weight + if hasattr(module, 'adapter'): + adapter = module.adapter + if hasattr(adapter, 'linear_in'): + lora_a_shape = tuple(adapter.linear_in.weight.shape) + lora_b_shape = tuple(adapter.linear_out.weight.shape) + 
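+                        # Printing .shape only reads tensor metadata; printing the tensor itself
+                        # dereferences GPU memory and can raise cudaErrorIllegalAddress while
+                        # weights are offloaded (see the note at the end of this file).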
print(adapter.linear_in.weight.shape) + # print(adapter.linear_in.weight) + # print(adapter.linear_in.weight.detach().cpu().clone()) + print(adapter.linear_out.weight.shape) + # print(adapter.linear_out.weight) + # print(adapter.linear_out.weight.detach().cpu().clone()) + print(f" {name}: lora_A={lora_a_shape}, lora_B={lora_b_shape}, requires_grad={adapter.linear_in.requires_grad}") + else: + for param_name, param in module.named_parameters(recurse=False): + # print(f" {name}.{param_name}: shape={tuple(param.shape)}", param) + # print(f" {name}.{param_name}: shape={tuple(param.shape)}", param.float().sum().item()) + print(f" {name}.{param_name}: shape={tuple(param.shape)}") + + + # account parms + for p in model.parameters(): + results["total_params"] += p.numel() + if p.requires_grad: + results["trainable_params"] += p.numel() + else: + results["frozen_params"] += p.numel() + + print(f"LoRA Check Results:") + print(f" - LoRA modules found: {len(results['lora_modules'])}") + print(f" - Trainable params: {results['trainable_params']:,}") + print(f" - Frozen params: {results['frozen_params']:,}") + print(f" - Trainable ratio: {results['trainable_params']/results['total_params']*100:.2f}%") + + return results + ############################## + ############################## + ############################## + + # │ print(tensor.shape) ✓ │ + # │ └── only read metadata(do not need GPU memory) │ + # │ │ + # │ print(tensor) ✗ │ + # │ └── PyTorch __repr__ will read the actual tensor values │ + # │ └── When GPU try to read the data, it has cudaErrorIllegalAddress \ No newline at end of file diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index aea72ceb8..8e00dc128 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -11,6 +11,7 @@ def set_default_megatron_args(args): # always use zero optimizer args.use_distributed_optimizer = True + # TODO: maybe change this after megatron has good fp8 support args.bf16 = not args.fp16 # placeholders diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 35d712910..24fb2c6cc 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -10,9 +10,25 @@ from miles.utils import megatron_bridge_utils +############################## +###########lora############### +############################## +from miles.backends.megatron_utils.lora_utils import is_lora_model, save_lora_checkpoint, load_lora_checkpoint +############################## +############################## +############################## + logger = logging.getLogger(__name__) -__all__ = ["save_checkpoint"] +############################## +###########lora############### +############################## +# __all__ = ["save_checkpoint"] +__all__ = ["save_checkpoint", "save_checkpoint_with_lora", "load_checkpoint"] +############################## +############################## +############################## + def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt): @@ -24,6 +40,36 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con load_path ), f"{args.load=} does not exist or is an empty directory. Did you specify the wrong folder?" 
+ ############################## + ###########lora############### + ############################## + # # Check for LoRA adapter first + # ## Not correct - need to check the saving format and name + # if is_lora_model(ddp_model): + # lora_path = Path(load_path) / "adapter" + # if lora_path.exists(): + # logger.info(f"Loading LoRA checkpoint from {lora_path}") + # iteration = load_lora_checkpoint(ddp_model, args, str(lora_path)) + # num_floating_point_operations_so_far = 0 + # return iteration, num_floating_point_operations_so_far + # else: + # logger.info(f"Not Found LoRA checkpoint from {lora_path}. Use the random initial weight.") + ############################## + ############################## + ############################## + + + ############################## + ###########lora############### + ############################## + # (to-do): yusheng- Also add lora weight loading in `_load_checkpoint_megatron` and `_load_checkpoint_hf` + # if no lora weight - random initalization + ############################## + ############################## + ############################## + ############################## + + if _is_megatron_checkpoint(load_path): return _load_checkpoint_megatron( ddp_model=ddp_model, @@ -40,6 +86,24 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_con load_path=load_path, ) +############################## +###########lora############### +############################## +def save_checkpoint_with_lora(iteration, model, optimizer, opt_param_scheduler): + """Extended save that handles LoRA adapters separately.""" + args = get_args() + + if is_lora_model(model): + # Save only LoRA adapter weights + save_dir = Path(args.save) / f"iter_{iteration:07d}" / "adapter" + logger.info(f"Saving LoRA checkpoint to {save_dir}") + save_lora_checkpoint(model, args, str(save_dir)) + else: + # Use standard Megatron save + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) +############################## +############################## +############################## def _is_megatron_checkpoint(path: str | Path) -> bool: return (Path(path) / "latest_checkpointed_iteration.txt").is_file() or bool( diff --git a/miles/backends/megatron_utils/lora_utils.py b/miles/backends/megatron_utils/lora_utils.py new file mode 100644 index 000000000..eb130a50c --- /dev/null +++ b/miles/backends/megatron_utils/lora_utils.py @@ -0,0 +1,659 @@ +############################## +###########lora############### +############################## +# to-do(yusheng): this should be moved to utils or split into hf_weight_iterator_bridge.py + +"""LoRA utilities for Megatron backend using Megatron-Bridge PEFT integration.""" + +import logging +import os +from argparse import Namespace +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +from megatron.core import mpu + +logger = logging.getLogger(__name__) + +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def is_lora_enabled(args: Namespace) -> bool: + """Check if LoRA is enabled.""" + return args.lora_rank > 0 or args.lora_adapter_path is not None + + +# def apply_lora_to_megatron_model( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> Sequence[torch.nn.Module]: +# """Apply LoRA to Megatron model using Megatron-Bridge PEFT integration. 
+ +# This uses the Megatron-Bridge's PEFT support from: +# https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/src/megatron/bridge/peft + +# Note: in this version implementation, we use this Megatron-Bridge branch: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0 + +# Args: +# model: Megatron model (DDP wrapped) +# args: Training arguments with LoRA config + +# Returns: +# LoRA-wrapped model +# """ +# # from megatron.bridge.peft import apply_lora_adapter, LoraConfig +# from megatron.bridge.peft.lora import LoRA + +# if args.lora_adapter_path: +# # TODO: Loading existing LoRA adapter needs separate implementation +# # Megatron-Bridge may have different API for loading +# # Refer to this one: https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# raise NotImplementedError("Loading existing LoRA adapter is not yet implemented") +# else: +# # Determine lora_dtype from args +# if hasattr(args, 'bf16') and args.bf16: +# lora_dtype = torch.bfloat16 +# elif hasattr(args, 'fp16') and args.fp16: +# lora_dtype = torch.float16 +# else: +# lora_dtype = None # Will use model's dtype + +# # Get exclude_modules as list +# exclude_modules = [] +# if hasattr(args, 'exclude_modules') and args.exclude_modules: +# if isinstance(args.exclude_modules, str): +# exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] +# else: +# exclude_modules = list(args.exclude_modules) + +# # Create new LoRA adapter using Megatron-Bridge LoRA dataclass +# # There are different lora_type, I just use the classic one (speed and acc might not the optimal) +# # https://github.com/volcengine/verl/pull/4063/files#diff-10d5abfbdb508c9478018ad08f295686a960701639fc4e3f3c24a4bdc2f0b711 +# lora = LoRA( +# target_modules=args.target_modules, # e.g., ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] +# exclude_modules=exclude_modules, # Modules to exclude from LoRA +# dim=args.lora_rank, # LoRA rank (called 'dim' in Megatron-Bridge) +# alpha=args.lora_alpha, # LoRA alpha scaling factor +# dropout=args.lora_dropout, # LoRA dropout rate +# dropout_position=getattr(args, 'lora_dropout_position', 'pre'), # 'pre' or 'post' +# lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), # Initialization for LoRA A matrix +# lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), # Initialization for LoRA B matrix +# a2a_experimental=getattr(args, 'lora_a2a_experimental', False), # Experimental All-to-All communication +# lora_dtype=lora_dtype, # Parameter data type for LoRA weights +# ) +# logger.info(f"Applying LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}, " +# f"dropout={args.lora_dropout}, target_modules={args.target_modules}, " +# f"exclude_modules={exclude_modules}, lora_dtype={lora_dtype}") + +# # Apply LoRA to each model chunk +# # The LoRA class is callable - calling it applies the transformation +# for model_chunk in model: +# # lora(model_chunk.module, training=True) applies LoRA and freezes base model +# lora(model_chunk.module, training=True) + +# # Print trainable parameters info +# _print_trainable_parameters(model) + +# return model + + + +# def print_adapter_info(model): +# """Print information about adapter parameters in the model.""" +# adapter_params, total_params, percentage = count_adapter_parameters(model) + +# print(f"\n{'=' * 60}") +# print("PEFT Adapter Information:") +# print(f" Total parameters: {total_params:,}") +# print(f" Adapter parameters: {adapter_params:,}") +# 
print(f" Trainable percentage: {percentage:.2f}%") +# print(f"{'=' * 60}\n") + + +def _print_trainable_parameters(model: Sequence[torch.nn.Module]) -> None: + """Print trainable parameters statistics.""" + total_params = 0 + trainable_params = 0 + trainable_param_names = [] + + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + total_params += param.numel() + if param.requires_grad: + trainable_params += param.numel() + trainable_param_names.append((name, param.numel())) + + if mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0: + logger.info( + f"LoRA trainable params: {trainable_params:,} / {total_params:,} " + f"({100 * trainable_params / total_params:.2f}%)" + ) + # if trainable_param_names: + # logger.info(f"\nTrainable layers ({len(trainable_param_names)} parameters):") + # for name, num_params in trainable_param_names: + # logger.info(f" ✓ {name}: {num_params:,} params") + # else: + # logger.warning("⚠️ NO TRAINABLE PARAMETERS! LoRA may not be applied correctly.") + + +def is_lora_model(model: Sequence[torch.nn.Module]) -> bool: + """Check if model has LoRA layers applied.""" + + for model_chunk in model: + if hasattr(model_chunk.module, "peft_config"): + return True + # Check for LoRA layers in parameters + for name, _ in model_chunk.named_parameters(): + if "lora_" in name or "adapter" in name: + return True + return False + + +# def get_lora_state_dict( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> dict[str, torch.Tensor]: +# """Extract LoRA weights from model. + +# Returns only the LoRA adapter weights, not the base model weights. +# """ +# from miles.backends.megatron_utils.update_weight.common import named_params_and_buffers + +# lora_state_dict = {} + +# for name, param in named_params_and_buffers(args, model, convert_to_global_name=True): +# if "lora_" in name or ".adapter." in name: +# lora_state_dict[name] = param + +# return lora_state_dict + + +# def get_lora_weights_and_config( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: +# """Extract LoRA weights and config for tensor-based sync. + +# This is used for efficient weight sync to SGLang engines. +# """ +# lora_state_dict = get_lora_state_dict(model, args) + +# # Convert Megatron names to HF-compatible names for SGLang +# hf_state_dict = {} +# for name, param in lora_state_dict.items(): +# # Convert megatron naming to HF naming +# hf_name = _convert_megatron_to_hf_lora_name(name) +# hf_state_dict[hf_name] = param + +# config_dict = { +# "peft_type": "LORA", +# "r": args.lora_rank, +# "lora_alpha": args.lora_alpha, +# "target_modules": list(args.target_modules), +# "bias": "none", +# } + +# if mpu.get_data_parallel_rank() == 0: +# logger.info(f"Extracted {len(hf_state_dict)} LoRA weight tensors for sync") + +# return hf_state_dict, config_dict + + +# def _convert_megatron_to_hf_lora_name(name: str) -> str: +# """Convert Megatron LoRA parameter name to HuggingFace format. 
+ +# Megatron: module.module.decoder.layers.0.self_attention.linear_qkv.lora_A.weight +# HF: model.layers.0.self_attn.q_proj.lora_A.weight +# """ +# # This mapping should match your specific model architecture +# replacements = [ +# ("module.module.decoder.layers.", "model.layers."), +# (".self_attention.linear_qkv.lora_", ".self_attn.q_proj.lora_"), +# (".self_attention.linear_proj.lora_", ".self_attn.o_proj.lora_"), +# (".mlp.linear_fc1.lora_", ".mlp.gate_proj.lora_"), +# (".mlp.linear_fc2.lora_", ".mlp.down_proj.lora_"), +# ] + +# result = name +# for old, new in replacements: +# result = result.replace(old, new) + +# return result + + +# def save_lora_checkpoint( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# save_dir: str, +# ) -> str: +# """Save LoRA adapter checkpoint to disk. + +# Args: +# model: Megatron model with LoRA +# args: Training arguments +# save_dir: Directory to save checkpoint + +# Returns: +# Path to saved checkpoint +# """ +# from megatron.bridge.peft import save_lora_adapter + +# save_path = Path(save_dir) +# save_path.mkdir(parents=True, exist_ok=True) + +# # Use Megatron-Bridge's save function +# if mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0: +# for model_chunk in model: +# save_lora_adapter(model_chunk.module, str(save_path)) +# os.sync() +# logger.info(f"Saved LoRA adapter to {save_path}") + +# dist.barrier() +# return str(save_path) + + +## to-do (yusheng): need to confirm usage +def save_lora_checkpoint( + model: Sequence[torch.nn.Module], + args: Namespace, + save_dir: str, +) -> str: + """Save LoRA adapter checkpoint to disk in HuggingFace PEFT format. + + Since Megatron-Bridge doesn't have a save_lora_adapter function, + we manually extract adapter weights and convert to PEFT format. 
+ """ + import json + from pathlib import Path + from megatron.bridge.peft.lora_layers import LoRALinear, LinearAdapter, TELinearAdapter + from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + + save_path = Path(save_dir) + + # Only rank 0 saves (other ranks just return) + if not (mpu.get_data_parallel_rank() == 0 and mpu.get_tensor_model_parallel_rank() == 0): + return str(save_path) + + save_path.mkdir(parents=True, exist_ok=True) + + lora_state_dict = {} + + for model_chunk in model: + for name, module in model_chunk.named_modules(): + linear_in = None + linear_out = None + + # LoRALinear (wraps base layer with adapter) + if isinstance(module, AdapterWrapper) and hasattr(module, 'adapter'): + adapter = module.adapter + if hasattr(adapter, 'linear_in') and hasattr(adapter, 'linear_out'): + linear_in = adapter.linear_in + linear_out = adapter.linear_out + # LinearAdapter/TELinearAdapter (extends nn.Linear with lora) + elif isinstance(module, (LinearAdapter, TELinearAdapter)): + if hasattr(module, 'linear_in') and hasattr(module, 'linear_out'): + linear_in = module.linear_in + linear_out = module.linear_out + + if linear_in is not None and linear_out is not None: + # Convert Megatron naming to HF PEFT naming + base_name = name.replace("module.module.", "base_model.model.") + base_name = base_name.replace(".decoder.layers.", ".model.layers.") + base_name = base_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + base_name = base_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") + base_name = base_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") + base_name = base_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + + lora_state_dict[f"{base_name}.lora_A.weight"] = linear_in.weight.data.cpu() + lora_state_dict[f"{base_name}.lora_B.weight"] = linear_out.weight.data.cpu() + + # Save weights + torch.save(lora_state_dict, save_path / "adapter_model.bin") + + # Save PEFT config + config = { + "peft_type": "LORA", + "r": args.lora_rank, + "lora_alpha": args.lora_alpha, + "target_modules": list(args.target_modules) if args.target_modules else ["q_proj", "o_proj", "gate_proj", "down_proj"], + "bias": "none", + "task_type": "CAUSAL_LM", + } + with open(save_path / "adapter_config.json", "w") as f: + json.dump(config, f, indent=2) + + os.sync() + logger.info(f"Saved LoRA adapter to {save_path} with {len(lora_state_dict)} tensors") + + return str(save_path) + + + + +def load_lora_checkpoint( + model: Sequence[torch.nn.Module], + args: Namespace, + load_dir: str, +) -> None: + """Load LoRA adapter checkpoint from disk. + + Args: + model: Megatron model + args: Training arguments + load_dir: Directory containing checkpoint + """ + from megatron.bridge.peft import load_lora_adapter + + load_path = Path(load_dir) + if not load_path.exists(): + raise FileNotFoundError(f"LoRA checkpoint not found at {load_path}") + + logger.info(f"Loading LoRA adapter from {load_path}") + + for model_chunk in model: + load_lora_adapter(model_chunk.module, str(load_path)) + + dist.barrier() + + +# ## to-do (yusheng): need to confirm usage +# def load_lora_checkpoint( +# model: Sequence[torch.nn.Module], +# args: Namespace, +# load_dir: str, +# ) -> None: +# """Load LoRA adapter checkpoint from disk. + +# Note: This loads PEFT-format checkpoints into Megatron-Bridge LoRA layers. +# The checkpoint must be in HuggingFace PEFT format (adapter_model.bin + adapter_config.json). 
+# """ +# import json +# from pathlib import Path +# from megatron.bridge.peft.lora_layers import LoRALinear, LinearAdapter, TELinearAdapter +# from megatron.bridge.peft.adapter_wrapper import AdapterWrapper + +# load_path = Path(load_dir) +# if not load_path.exists(): +# raise FileNotFoundError(f"LoRA checkpoint not found at {load_path}") + +# # Load state dict +# state_dict_path = load_path / "adapter_model.bin" +# if not state_dict_path.exists(): +# raise FileNotFoundError(f"adapter_model.bin not found in {load_path}") + +# lora_state_dict = torch.load(state_dict_path, map_location="cpu") + +# logger.info(f"Loading LoRA adapter from {load_path} with {len(lora_state_dict)} tensors") + +# # Build reverse name mapping (HF -> Megatron) +# def hf_to_megatron_name(hf_name: str) -> str: +# name = hf_name.replace("base_model.model.", "module.module.") +# name = name.replace(".model.layers.", ".decoder.layers.") +# name = name.replace(".self_attn.q_proj", ".self_attention.linear_qkv") +# name = name.replace(".self_attn.o_proj", ".self_attention.linear_proj") +# name = name.replace(".mlp.gate_proj", ".mlp.linear_fc1") +# name = name.replace(".mlp.down_proj", ".mlp.linear_fc2") +# return name + +# # Load weights into model +# for model_chunk in model: +# for name, module in model_chunk.named_modules(): +# linear_in = None +# linear_out = None + +# if isinstance(module, AdapterWrapper) and hasattr(module, 'adapter'): +# adapter = module.adapter +# if hasattr(adapter, 'linear_in') and hasattr(adapter, 'linear_out'): +# linear_in = adapter.linear_in +# linear_out = adapter.linear_out +# elif isinstance(module, (LinearAdapter, TELinearAdapter)): +# if hasattr(module, 'linear_in') and hasattr(module, 'linear_out'): +# linear_in = module.linear_in +# linear_out = module.linear_out + +# if linear_in is not None and linear_out is not None: +# # Find corresponding HF name +# base_name = name.replace("module.module.", "base_model.model.") +# base_name = base_name.replace(".decoder.layers.", ".model.layers.") +# base_name = base_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") +# base_name = base_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") +# base_name = base_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") +# base_name = base_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + +# lora_a_key = f"{base_name}.lora_A.weight" +# lora_b_key = f"{base_name}.lora_B.weight" + +# if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict: +# linear_in.weight.data.copy_(lora_state_dict[lora_a_key].to(linear_in.weight.device)) +# linear_out.weight.data.copy_(lora_state_dict[lora_b_key].to(linear_out.weight.device)) + +# dist.barrier() +# logger.info(f"Successfully loaded LoRA adapter from {load_path}") + + + +## Check this functions - megatron-bridge might have the same function +####!!!! (to-do) yusheng: need to based on different Lora to provide the different mapping +from typing import Union, Type, TYPE_CHECKING +if TYPE_CHECKING: + from megatron.bridge.peft.lora import LoRA + from megatron.bridge.peft.canonical_lora import CanonicalLoRA + +def convert_target_modules_to_megatron( + hf_modules: list[str], + lora_type: Union[Type, object, None] = None +) -> list[str]: + """Convert HuggingFace LoRA target module names to Megatron format. 
+ + HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj + + Megatron (Standard LoRA): linear_qkv, linear_proj, linear_fc1, linear_fc2 + Megatron (CanonicalLoRA): linear_q, linear_k, linear_v, linear_proj, + linear_fc1_up, linear_fc1_gate, linear_fc2 + + Special cases: + - "all", "all-linear", "all_linear" -> returns all standard Megatron linear modules + + Args: + hf_modules: List of HuggingFace module names or single string + lora_type: LoRA class or instance (LoRA, CanonicalLoRA, etc.) + If None, defaults to CanonicalLoRA format + + Returns: + List of Megatron module names + + If input is already in Megatron format, returns as-is without conversion. + """ + # Get the class name whether lora_type is a class or an instance + if lora_type is not None: + if isinstance(lora_type, type): + # It's a class + class_name = lora_type.__name__ + else: + # It's an instance + class_name = type(lora_type).__name__ + + logger.info(f"Converting target modules for {class_name}") + else: + # Default to CanonicalLoRA if not specified + class_name = "CanonicalLoRA" + logger.info(f"Converting target modules (defaulting to CanonicalLoRA)") + + # Handle special cases for "all" variants + if isinstance(hf_modules, str): + if hf_modules in ["all", "all-linear", "all_linear"]: + if class_name == "CanonicalLoRA": + return ["linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + else: # Standard LoRA + return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + # Convert single string to list + hf_modules = [hf_modules] + elif isinstance(hf_modules, list) and len(hf_modules) == 1: + if hf_modules[0] in ["all", "all-linear", "all_linear"]: + if class_name == "CanonicalLoRA": + return ["linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + else: # Standard LoRA + return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + + # Define module name sets based on LoRA type + if class_name == "CanonicalLoRA": + megatron_modules_set = { + "linear_q", "linear_k", "linear_v", "linear_proj", + "linear_fc1_up", "linear_fc1_gate", "linear_fc2" + } + else: # Standard LoRA + megatron_modules_set = {"linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"} + + hf_modules_set = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + + # Check if all modules are already in Megatron format (or wildcards) + all_megatron_format = True + for module in hf_modules: + # Skip wildcard patterns (e.g., "*.layers.0.*.linear_qkv") + if "*" in module: + continue + # Check if it's a known HF module name + if module in hf_modules_set: + all_megatron_format = False + break + + # If already in Megatron format, return as-is + if all_megatron_format: + return hf_modules + + # Otherwise, perform conversion based on LoRA type + if class_name == "CanonicalLoRA": + # CanonicalLoRA: Split Q/K/V and up/gate + hf_to_megatron = { + "q_proj": "linear_q", + "k_proj": "linear_k", + "v_proj": "linear_v", + "o_proj": "linear_proj", + "gate_proj": "linear_fc1_gate", + "up_proj": "linear_fc1_up", + "down_proj": "linear_fc2", + } + elif class_name == "LoRA": + # Standard LoRA + # Standard LoRA: Merged Q/K/V and merged up/gate + hf_to_megatron = { + "q_proj": "linear_qkv", + "k_proj": "linear_qkv", + "v_proj": "linear_qkv", + "o_proj": "linear_proj", + "gate_proj": "linear_fc1", + "up_proj": "linear_fc1", + "down_proj": "linear_fc2", + } + else: + raise NotImplementedError(f"Unsupported LoRA class: {class_name}") + + 
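+    # Apply the mapping with de-duplication: for fused (standard) LoRA, q/k/v all collapse
+    # to linear_qkv and gate/up to linear_fc1; names not in the mapping (Megatron names or
+    # wildcard patterns) pass through unchanged.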
megatron_modules = [] + for module in hf_modules: + if module in hf_to_megatron: + megatron_name = hf_to_megatron[module] + if megatron_name not in megatron_modules: + megatron_modules.append(megatron_name) + else: + # Keep as-is if not in mapping (might already be Megatron format or wildcard) + if module not in megatron_modules: + megatron_modules.append(module) + + return megatron_modules + + + +# def convert_target_modules_to_megatron(hf_modules: list[str]) -> list[str]: +# """Convert HuggingFace LoRA target module names to Megatron format. + +# HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj +# Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + +# Special cases: +# - "all", "all-linear", "all_linear" -> returns all standard Megatron linear modules + +# If input is already in Megatron format, returns as-is without conversion. +# """ +# # Handle special cases for "all" variants +# if isinstance(hf_modules, str): +# if hf_modules in ["all", "all-linear", "all_linear"]: +# return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] +# # Convert single string to list +# hf_modules = [hf_modules] +# elif isinstance(hf_modules, list) and len(hf_modules) == 1: +# if hf_modules[0] in ["all", "all-linear", "all_linear"]: +# return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + +# # Define Megatron and HF module name sets +# megatron_modules_set = {"linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"} +# hf_modules_set = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"} + +# # Check if all modules are already in Megatron format (or wildcards) +# all_megatron_format = True +# for module in hf_modules: +# # Skip wildcard patterns (e.g., "*.layers.0.*.linear_qkv") +# if "*" in module: +# continue +# # Check if it's a known HF module name +# if module in hf_modules_set: +# all_megatron_format = False +# break + +# # If already in Megatron format, return as-is +# if all_megatron_format: +# return hf_modules + +# # Otherwise, perform conversion +# hf_to_megatron = { +# "q_proj": "linear_qkv", +# "k_proj": "linear_qkv", +# "v_proj": "linear_qkv", +# "o_proj": "linear_proj", +# "gate_proj": "linear_fc1", +# "up_proj": "linear_fc1", +# "down_proj": "linear_fc2", +# } + +# megatron_modules = [] +# for module in hf_modules: +# if module in hf_to_megatron: +# megatron_name = hf_to_megatron[module] +# if megatron_name not in megatron_modules: +# megatron_modules.append(megatron_name) +# else: +# # Keep as-is if not in mapping (might already be Megatron format or wildcard) +# if module not in megatron_modules: +# megatron_modules.append(module) + +# return megatron_modules + + + + +def freeze_base_model(model: Sequence[torch.nn.Module]) -> None: + """Freeze base model parameters, only keep LoRA trainable.""" + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if "lora_" not in name and "adapter" not in name: + param.requires_grad = False + + +def get_trainable_params_for_optimizer( + model: Sequence[torch.nn.Module], +) -> list[torch.nn.Parameter]: + """Get only trainable parameters for optimizer (LoRA params only).""" + trainable_params = [] + for model_chunk in model: + for param in model_chunk.parameters(): + if param.requires_grad: + trainable_params.append(param) + return trainable_params +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/megatron_utils/megatron_to_hf/__init__.py 
b/miles/backends/megatron_utils/megatron_to_hf/__init__.py index ba5a286a3..bc1ba073b 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/__init__.py @@ -83,3 +83,51 @@ def _convert_to_hf_core(args, model_name, name, param): else: converted_named_tensors.append((converted_name, converted_param)) return converted_named_tensors + +############################## +###########lora############### +############################## +### This might be model specific --> make it more general +def convert_lora_to_hf(args, model_name, name, param): + """ + Convert Megatron LoRA parameter to HuggingFace PEFT format. + + Megatron format: module.module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight + HF PEFT format: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight + """ + # Determine if this is lora_A (linear_in) or lora_B (linear_out) + if ".linear_in." in name or ".lora_A." in name: + lora_suffix = "lora_A.weight" + elif ".linear_out." in name or ".lora_B." in name: + lora_suffix = "lora_B.weight" + else: + # Fallback - return as is + return [(name, param)] + + # Convert Megatron naming to HF PEFT naming + hf_name = name + + # Remove Megatron wrapper prefixes + hf_name = hf_name.replace("module.module.", "base_model.model.") + + # Convert layer path + hf_name = hf_name.replace(".decoder.layers.", ".model.layers.") + + # Convert attention modules + hf_name = hf_name.replace(".self_attention.linear_qkv", ".self_attn.q_proj") + hf_name = hf_name.replace(".self_attention.linear_proj", ".self_attn.o_proj") + + # Convert MLP modules + hf_name = hf_name.replace(".mlp.linear_fc1", ".mlp.gate_proj") + hf_name = hf_name.replace(".mlp.linear_fc2", ".mlp.down_proj") + + # Replace adapter naming with lora naming + hf_name = hf_name.replace(".adapter.linear_in.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".adapter.linear_out.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".lora_A.weight", f".{lora_suffix}") + hf_name = hf_name.replace(".lora_B.weight", f".{lora_suffix}") + + return [(hf_name, param)] +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 780370453..910d6a617 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -25,7 +25,19 @@ from miles.utils import tracking_utils from miles.utils.memory_utils import clear_memory -from .checkpoint import load_checkpoint, save_checkpoint +############################## +###########lora############### +############################## +# from .checkpoint import load_checkpoint, save_checkpoint +from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora +from .lora_utils import is_lora_model +# from miles.backends.megatron_utils.lora_utils import is_lora_enabled +# from miles.backends.megatron_utils.lora_utils import is_lora_enabled, apply_lora_to_megatron_model +from miles.backends.megatron_utils.lora_utils import is_lora_enabled, convert_target_modules_to_megatron +############################## +############################## +############################## + from .data import DataIterator, get_batch from .loss import loss_function from .model_provider import get_model_provider_func @@ -33,6 +45,73 @@ logger = logging.getLogger(__name__) +############################## +###########lora############### 
+############################## +from dataclasses import dataclass +@dataclass +class McoreModuleWrapperConfig: + """Configuration for Mcore module wrapper.""" + is_value_model: bool = False + share_embeddings_and_output_weights: bool = False + wrap_with_ddp: bool = True + use_distributed_optimizer: bool = True + + +def _ensure_model_list(model): + return model if isinstance(model, list) else [model] + +# Get from: verl/verl/models/mcore/bridge.py +def make_value_model(hidden_size, sequence_parallel): + """Creates a pre-wrap hook that replace the output layer with a value head. + + Args: + hidden_size (int): The hidden size of the model's transformer layers. + sequence_parallel (bool): Whether sequence parallelism is enabled. + + Returns: + A hook function that can be used as a `pre_wrap_hook` in Megatron-Bridge. + The hook itself takes the model as input and prepares it for value head activation. + """ + + from megatron.core import parallel_state + from .model_provider import LinearForLastLayer + + def hook(model): + model_post_process = [] + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + model_post_process.append(parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)) + else: + model_post_process.append(parallel_state.is_pipeline_last_stage()) + + model_list = _ensure_model_list(model) + assert len(model_post_process) == len(model_list), "Model list length and post process list length must match." + + for index, model_chunk in enumerate(model_list): + if not model_post_process[index]: + continue + + model_chunk.output_layer = LinearForLastLayer( + input_size=hidden_size, + output_size=1, + sequence_parallel=sequence_parallel, + ) + + return hook + + +from megatron.core.utils import get_attr_wrapped_model +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) +############################## +############################## +############################## + + def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer) -> OptimizerParamScheduler: """Create and configure the optimizer learning-rate/weight-decay scheduler. 
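# Usage sketch for the value-head hook above (illustration only, assuming the
# Megatron-Bridge provider API exercised later in setup_model_and_optimizer):
#
#     provider = bridge.to_megatron_provider(load_weights=False)
#     hook = make_value_model(hf_config.hidden_size, provider.sequence_parallel)
#     provider.register_pre_wrap_hook(hook)  # replaces output_layer with a 1-dim LinearForLastLayer on last PP stages
#     model = provider.provide_distributed_model(wrap_with_ddp=True, ddp_config=ddp_config)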
@@ -105,7 +184,245 @@ def setup_model_and_optimizer( assert not args.moe_use_upcycling assert args.load is not None or args.pretrained_checkpoint is not None - model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + ############################## + ###########lora############### + ############################## + + + # This part can be moved to `lora_utils.py` def apply_lora_to_megatron_model + if is_lora_enabled(args) and role == "actor" and args.megatron_to_hf_mode == "bridge": + ###### + # refer to: verl/verl/workers/engine/megatron/transformer_impl.py + ###### + + # if is_lora_enabled(args) and args.megatron_to_hf_mode == "bridge": + # The below written as: get_model_provider_func() usage + # from megatron.core.distributed import DistributedDataParallelConfig + from megatron.bridge.models.model_provider import get_model as bridge_get_model + from megatron.bridge import AutoBridge + from megatron.bridge.peft.lora import LoRA + from megatron.bridge.peft.canonical_lora import CanonicalLoRA + import torch + + + # bridge : start + # hf config: + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + + # bridge: hf --> mg config + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + # bridge : done + ##!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # provider: start + # I can also use bridge_get_model() method + provider = bridge.to_megatron_provider(load_weights=False) # should be True??? + # provider = bridge.to_megatron_provider(load_weights=True) # different from full model training - in the training script, I need to load tuned base model weight and initial lora weights. Need to carefully check and optimize - where to load the base model? (but why in `model_provider.py` using: provider = bridge.to_megatron_provider(load_weights=False)) + + # Set parallel configs on the provider + provider.tensor_model_parallel_size = args.tensor_model_parallel_size + provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size + provider.expert_model_parallel_size = args.expert_model_parallel_size + provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size + provider.sequence_parallel = args.sequence_parallel + ##### + provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size + provider.context_parallel_size = args.context_parallel_size + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + # provider: done + provider.finalize() + ##### + + ########### + ########### lora part + ########### + + # peft_cls: start + # need: args.hf_checkpoint, bridge, provider, self.param_dtype: lora_dtype + + # Determine lora_dtype + if hasattr(args, 'bf16') and args.bf16: + lora_dtype = torch.bfloat16 + elif hasattr(args, 'fp16') and args.fp16: + lora_dtype = torch.float16 + else: + # !!!!!!!!!!!!!!!!!!!!!! 
+ # lora_dtype = None + lora_dtype = torch.float16 + + # (to-do) yusheng - lora_type - can be input arg, the default is LoRA + # Get exclude_modules as list + exclude_modules = [] + if hasattr(args, 'exclude_modules') and args.exclude_modules: + if isinstance(args.exclude_modules, str): + exclude_modules = [m.strip() for m in args.exclude_modules.split(",")] + else: + exclude_modules = list(args.exclude_modules) + # Convert HF module names to Megatron format + # exclude_modules = convert_target_modules_to_megatron(exclude_modules) + exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=LoRA) + # exclude_modules = convert_target_modules_to_megatron(exclude_modules, lora_type=CanonicalLoRA) + + # print("========") + # Create LoRA config + # (to-do) yusheng set - lora_type - the default is LoRA: should support all lora + lora = LoRA( + # lora = CanonicalLoRA( + target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=LoRA), + # target_modules=convert_target_modules_to_megatron(args.target_modules, lora_type=CanonicalLoRA), + # lora_dtype=lora_dtype, + exclude_modules=exclude_modules, + dim=args.lora_rank, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + ##Below for Lora + # dropout_position=getattr(args, 'lora_dropout_position', 'pre'), + lora_A_init_method=getattr(args, 'lora_A_init_method', 'xavier'), + lora_B_init_method=getattr(args, 'lora_B_init_method', 'zero'), + # a2a_experimental=getattr(args, 'lora_a2a_experimental', False), + ) + + + + # peft_cls: done + ################ --------------------################# + # ----- Above: + # self.bridge: init hf model --> load hf weight + # self.provider: init megatron model --> load megatron weight + # self.peft_cls: init load module --> load lora weight + + + # When using PEFT with Megatron-Bridge, we must apply PEFT transformation + # BEFORE wrapping the model in DDP. This is required because: + # 1. PEFT freezes base model parameters (requires_grad=False) + # 2. DDP must be aware of which parameters are trainable when building gradient buckets + # 3. 
The distributed optimizer must only track trainable (adapter) parameters + # See Megatron-Bridge docs: training/peft.md + + + # # Define pre_wrap_hook to apply LoRA before DDP wrapping + def apply_lora_hook(model_chunks): + transformed = lora(model_chunks, training=True) + lora.set_params_to_save(transformed) + # Load adapter weights if adapter_path is specified + # adapter_path = getattr(peft_config, "adapter_path", None) + # if adapter_path is not None and adapter_path: + # print(f"Loading adapter weights from: {adapter_path}") + # load_adapter_checkpoint(transformed_model, adapter_path) + + # Print PEFT statistics + # if torch.distributed.get_rank() == 0: + # print_adapter_info(transformed) + + return transformed + + + # Register the hook + provider.register_pre_wrap_hook(apply_lora_hook) + # provider.finalize() + + + #### ------------------- ##### + #### ------------------- ##### + #### ------------------- ##### + + + # TODO: add more cases + is_value_model = ( + "ForTokenClassification" in hf_config.architectures[0] + or "ForSequenceClassification" in hf_config.architectures[0] + ) + + + wrap_config = McoreModuleWrapperConfig( + is_value_model=is_value_model, # here is False # actor is not value model + # share_embeddings_and_output_weights=hf_config.share_embeddings_and_output_weights, + # wrap_with_ddp=wrap_with_ddp, #default is True + # use_distributed_optimizer=self.engine_config.use_distributed_optimizer, #default is True + ) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + post_model_creation_callbacks = [] + if wrap_config.is_value_model: + hidden_size = ( + hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size + ) + value_model_hook = make_value_model(hidden_size, provider.sequence_parallel) + post_model_creation_callbacks.append(value_model_hook) + # if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + # post_model_creation_callbacks.append(freeze_moe_router) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + for callback in post_model_creation_callbacks: + provider.register_pre_wrap_hook(callback) + + + if wrap_config.wrap_with_ddp: + from megatron.bridge.training.config import DistributedDataParallelConfig + ddp_config_dict = { + "use_distributed_optimizer": wrap_config.use_distributed_optimizer, + } + # Apply any DDP config overrides + # if override_ddp_config is not None: + # ddp_config_dict.update(override_ddp_config) + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) + ddp_config.finalize() + + # Now call provide_distributed_model with all hooks registered + # Hooks will be applied automatically before DDP wrapping + model = provider.provide_distributed_model( + wrap_with_ddp=wrap_config.wrap_with_ddp, + ddp_config=ddp_config, + ) + ## Use: + # model = bridge_get_model(...) 
seems more simple + + + # Extract TransformerConfig from the created model + tf_config = get_model_config(model[0] if isinstance(model, list) else model) + + ############# + # # Build DDP config - config + # ddp_config = DistributedDataParallelConfig( + # grad_reduce_in_fp32=getattr(args, 'grad_reduce_in_fp32', False), + # check_for_nan_in_grad=getattr(args, 'check_for_nan_in_grad', False), + # overlap_grad_reduce=getattr(args, 'overlap_grad_reduce', False), + # overlap_param_gather=getattr(args, 'overlap_param_gather', False), + # average_in_collective=getattr(args, 'average_in_collective', False), + # use_distributed_optimizer=getattr(args, 'use_distributed_optimizer', False), + # ) + + # good method + # # Use Bridge's get_model with the provider (which now has LoRA hook registered) + # model = bridge_get_model( # this function eqaul get_model() in megatron-core for base model + # model_provider=provider, + # ddp_config=ddp_config, + # model_type=ModelType.encoder_or_decoder, + # wrap_with_ddp=True, + # use_cpu_initialization=False, + # bf16=getattr(args, 'bf16', False), + # fp16=getattr(args, 'fp16', False), + # pre_wrap_hook=provider.pre_wrap_hook, + # ) + + + # ???? the below can be remove ??? or it's tag to jude the model is (base) or (base + lora) + # Store lora instance for later use (e.g., checkpoint saving) + # You may want to attach this to the model or args for later access + # if hasattr(args, '_lora_instance'): + # args._lora_instance = lora + ############# + + + else: + # Original non-LoRA path or non-bridge mode + model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder) + + ############################## + ############################## + ############################## # Optimizer kwargs = {} @@ -114,7 +431,6 @@ def setup_model_and_optimizer( kwargs[f.name] = getattr(args, f.name) config = OptimizerConfig(**kwargs) config.timers = None - optimizer = get_megatron_optimizer( config=config, model_chunks=model, @@ -703,16 +1019,44 @@ def save( args = get_args() if should_disable_forward_pre_hook(args): disable_forward_pre_hook(model) - save_checkpoint( - iteration, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far=0, - checkpointing_context=None, - train_data_iterator=None, - preprocess_common_state_dict_fn=None, - ) + + ############################## + ###########lora############### + ############################## + # save_checkpoint( + # iteration, + # model, + # optimizer, + # opt_param_scheduler, + # num_floating_point_operations_so_far=0, + # checkpointing_context=None, + # train_data_iterator=None, + # preprocess_common_state_dict_fn=None, + # ) + + if is_lora_model(model): + save_checkpoint_with_lora( + iteration, + model, + optimizer, + opt_param_scheduler, + ) + else: + save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far=0, + checkpointing_context=None, + train_data_iterator=None, + preprocess_common_state_dict_fn=None, + ) + + ############################## + ############################## + ############################## + if should_disable_forward_pre_hook(args): enable_forward_pre_hook(model) @@ -753,6 +1097,15 @@ def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: if should_log: logger.error(f"Failed to save HuggingFace format: {e}") + + ############################## + ###########lora############### + ############################## + # to-do: also need to impl lora saving + ############################## + 
############################## + ############################## + def initialize_model_and_optimizer( args: Namespace, role: str = "actor" diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 7834f1101..19aabfd42 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -19,6 +19,9 @@ from miles.utils.misc import load_function + + + # Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82 class LinearForLastLayer(torch.nn.Linear): def __init__( @@ -77,20 +80,24 @@ def wrapped_model_provider( return wrapped_model_provider + if args.megatron_to_hf_mode == "bridge": from megatron.bridge import AutoBridge bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) provider = bridge.to_megatron_provider(load_weights=False) + # TODO: we should not manually set this... provider.tensor_model_parallel_size = args.tensor_model_parallel_size provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size provider.expert_model_parallel_size = args.expert_model_parallel_size provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel + provider.finalize() return provider.provide - + + def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: """Builds the model. diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py index ef7d62e8a..9acbbd4a1 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_base.py @@ -14,11 +14,26 @@ def create(args, model, **kwargs): return c(args, model, **kwargs) - def __init__(self, args, model, model_name, quantization_config): + ############################## + ###########lora############### + ############################## + # def __init__(self, args, model, model_name, quantization_config): + def __init__(self, args, model, model_name, quantization_config, **kwargs): + ############################## + ############################## + ############################## self.args = args self.model = model self.model_name = model_name self.quantization_config = quantization_config + ############################## + ###########lora############### + ############################## + self.is_lora = kwargs.pop('is_lora', False) + self._base_synced = kwargs.pop('_base_synced', False) + ############################## + ############################## + ############################## @abstractmethod def get_hf_weight_chunks(self, megatron_local_weights): diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 7e0a4817e..d82740381 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -1,4 +1,5 @@ import dataclasses +import torch from miles.utils import megatron_bridge_utils from miles.utils.iter_utils import chunk_named_params_by_size @@ -7,11 +8,50 @@ from ..misc_utils import strip_param_name_prefix from .hf_weight_iterator_base import HfWeightIteratorBase +############################## 
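# NOTE (illustrative, not part of the patch): when the bridge exports base weights with the adapters
# left unmerged, frozen linears are emitted as '...<module>.base_layer.weight'; the helper defined
# below strips that suffix back to the plain HF name the rollout engine expects, e.g.
#     "model.layers.0.self_attn.q_proj.base_layer.weight"
#         -> "model.layers.0.self_attn.q_proj.weight"
# (the layer index and module name are only an example).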
+###########lora############### +############################## +def _normalize_base_weight_name(param_name: str) -> str: + """Remove the 'base_layer' suffix emitted when merge_adapter_weights=False.""" + if param_name.endswith("base_layer.weight"): + return param_name[: -len("base_layer.weight")] + "weight" + return param_name + +# CanonicalLoRA - same as sglang +# Lora - need to below convert (or sglang also need to use Lora) +# def _convert_lora_name_for_sglang(hf_param_name: str) -> str: +# """Convert standard HF LoRA names to SGLang's merged projection format.""" +# # Handle attention projections - SGLang uses merged qkv_proj +# for proj in ['q_proj', 'k_proj', 'v_proj']: +# if f'.self_attn.{proj}.' in hf_param_name: +# return hf_param_name.replace(f'.self_attn.{proj}.', '.self_attn.qkv_proj.') + +# # Handle MLP projections - SGLang uses merged gate_up_proj +# for proj in ['gate_proj', 'up_proj']: +# if f'.mlp.{proj}.' in hf_param_name: +# return hf_param_name.replace(f'.mlp.{proj}.', '.mlp.gate_up_proj.') + +# return hf_param_name +############################## +############################## +############################## + class HfWeightIteratorBridge(HfWeightIteratorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # ############################## + # ###########lora############### + # ############################## + # self.is_lora = is_lora # already get from HfWeightIteratorBase + # self._base_synced = _base_synced # already get from HfWeightIteratorBase + # ############################## + # ############################## + # ############################## + from megatron.bridge import AutoBridge import miles_plugins.megatron_bridge # noqa: F401 @@ -20,27 +60,128 @@ def __init__(self, *args, **kwargs): def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. 
modify megatron-bridge to provide megatron param name) + renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} + + with megatron_bridge_utils.patch_megatron_model(self.model): - conversion_tasks = self._bridge.get_conversion_tasks(self.model) - conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - - named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) - - named_weights = ( - ( - hf_param_name, - postprocess_hf_param( - args=self.args, - megatron_param_name=megatron_param_name, - hf_param_name=hf_param_name, - param=weight, - ), + ############################## + ###########lora############### + ############################## + # ## This is the origin way - weight sync will process - base model + lora weights + # ## to-do (yusheng): Optimize: use the method in `self.is_lora` but need to deal with CUDA issue (weight not on the same device) - might need to be delt with in megatron-core + + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + + # named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + + # # for hf_param_name, weight, megatron_param_name in named_weights: + # # print(hf_param_name) + # # exit() + + # named_weights = ( + # ( + # hf_param_name, + # postprocess_hf_param( + # args=self.args, + # megatron_param_name=megatron_param_name, + # hf_param_name=hf_param_name, + # param=weight, + # ), + # ) + # for hf_param_name, weight, megatron_param_name in named_weights + # ) + + # yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + + #### + + # Only sync base model on first call - smile/miles need (or if not LoRA-only mode) + # if not self.is_lora or self._base_synced: + if not self.is_lora: + # only pass base model + conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + named_weights = self._bridge.export_hf_weights( + self.model, + cpu=False, + conversion_tasks=conversion_tasks, + # merge_adapter_weights=not self.is_lora, # Do not return merged (base.weight + lora.weight). 
+ ) + + # for hf_param_name, weight, megatron_param_name in named_weights: + # print(hf_param_name) + + named_weights = ( + ( + ############################## + ###########lora############### + ############################## + hf_param_name, + # _normalize_base_weight_name(hf_param_name), + ############################## + ############################## + ############################## + postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + ), + ) + for hf_param_name, weight, megatron_param_name in named_weights + ) + + yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + + if self.is_lora: + self._base_synced = False + # torch.cuda.synchronize() + ############################## + ############################## + ############################## + + + + ############################## + ###########lora############### + ############################## + if self.is_lora: + # (to-do) yusheng: I might need to add the converting weights (mg --> hf) - refer above + # conversion_tasks = self._bridge.get_conversion_tasks(self.model) + # conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) + + # print(self.model) # Identity() + + lora_weights = self._bridge.export_adapter_weights( + self.model, + cpu=False, + # cpu=True, ### if False, it will have cudaaccess error + # conversion_tasks=conversion_tasks, #### + show_progress=False + ) + + + # hf_param_name's might have big problem + lora_weights = ( + ( + hf_param_name, + postprocess_hf_param( # check if need postprocess_hf_param + args=self.args, + megatron_param_name=megatron_param_name, + hf_param_name=hf_param_name, + param=weight, + # param=weight.clone(), # solutuon - need to have self._bridge.build_adapter_conversion_tasks in megatron-bridge + ), + ) + for hf_param_name, weight, megatron_param_name in lora_weights ) - for hf_param_name, weight, megatron_param_name in named_weights - ) - yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) + yield from chunk_named_params_by_size(lora_weights, chunk_size=self.args.update_weight_buffer_size) + ############################## + ############################## + ############################## def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict): diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py index af2250dc1..1beded53f 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py @@ -209,4 +209,4 @@ def _get_megatron_local_param_infos(args: Namespace, model: Sequence[torch.nn.Mo infos[i].dtype == param_info.dtype ), f"Parameter dtype mismatch: {infos[i].dtype} != {param_info.dtype}" - return param_infos + return param_infos \ No newline at end of file diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 527d3cfe9..5cf724e3e 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -19,6 +19,18 @@ update_weights_from_distributed, ) +############################## +###########lora############### +############################## +from 
miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled + +def _is_lora_weight(name: str) -> bool: + """Check if a weight name is a LoRA adapter weight.""" + return ".lora_A." in name or ".lora_B." in name +############################## +############################## +############################## + class UpdateWeightFromTensor: """ @@ -36,6 +48,13 @@ def __init__( *, model_name: str, quantization_config: dict[str, int | str | list[str]] | None, + ############################## + ###########lora############### + ############################## + is_lora: bool = False, + ############################## + ############################## + ############################## ) -> None: """ Compute param buckets, create IPC Gloo groups (rollout_num_gpus_per_engine ranks/group). @@ -46,11 +65,45 @@ def __init__( self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 + ############################## + ###########lora############### + ############################## + self.is_lora = is_lora + # self._lora_loaded = False + self._base_synced = True + # self._base_synced = False + + # self._hf_weight_iterator = HfWeightIteratorBase.create( + # args=args, model=model, model_name=model_name, quantization_config=quantization_config + # ) + self._hf_weight_iterator = HfWeightIteratorBase.create( - args=args, model=model, model_name=model_name, quantization_config=quantization_config + args=args, model=model, model_name=model_name, quantization_config=quantization_config, + is_lora=self.is_lora, + _base_synced=self._base_synced, ) + + # Store LoRA config for weight sync (from tensor) - why this - let me think think + if self.is_lora: + from miles.backends.sglang_utils.sglang_engine import convert_target_modules_to_hf + self._lora_config = { + "peft_type": "LORA", + "r": args.lora_rank, + "lora_alpha": args.lora_alpha, + "target_modules": convert_target_modules_to_hf(list(args.target_modules)) if args.target_modules else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "lora_dropout": args.lora_dropout, + "bias": "none", + "task_type": "CAUSAL_LM", + } + else: + self._lora_config = None + ############################## + ############################## + ############################## + + # create the group within megatron. 
for start_rank in range(0, dist.get_world_size(), self.args.rollout_num_gpus_per_engine): end_rank = start_rank + self.args.rollout_num_gpus_per_engine @@ -62,6 +115,7 @@ def __init__( self._model_update_groups = None + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: @@ -114,39 +168,100 @@ def update_weights(self) -> None: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) - megatron_local_weights = self.weights_getter() + megatron_local_weights = self.weights_getter() + for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): refs, long_lived_tensors = self._send_hf_params(hf_named_tensors) ray.get(refs) del long_lived_tensors - dist.barrier(group=get_gloo_group()) + + ############################## + ###########lora############### + ############################## + # def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: + + # all_refs = [] + + # refs_colocated, long_lived_tensors = _send_to_colocated_engine( + # hf_named_tensors, + # ipc_engine=self._ipc_engine, + # ipc_gather_src=self._ipc_gather_src, + # ipc_gather_group=self._ipc_gather_group, + # weight_version=self.weight_version, + # ) + # all_refs.extend(refs_colocated) + + # if self.use_distribute and self._is_distributed_src_rank: + # refs_distributed = update_weights_from_distributed( + # self._group_name, + # self._model_update_groups, + # self.weight_version, + # self.distributed_rollout_engines, + # hf_named_tensors, + # ) + # if refs_distributed: + # all_refs.extend(refs_distributed) + + # return all_refs, long_lived_tensors + + ############## + def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: + all_refs = [] + + # Separate LoRA weights from base weights + if self.is_lora: + base_tensors = [(n, t) for n, t in hf_named_tensors if not _is_lora_weight(n)] + lora_tensors = [(n, t) for n, t in hf_named_tensors if _is_lora_weight(n)] + else: + base_tensors = hf_named_tensors + lora_tensors = [] - refs_colocated, long_lived_tensors = _send_to_colocated_engine( - hf_named_tensors, - ipc_engine=self._ipc_engine, - ipc_gather_src=self._ipc_gather_src, - ipc_gather_group=self._ipc_gather_group, - weight_version=self.weight_version, - ) - all_refs.extend(refs_colocated) - - if self.use_distribute and self._is_distributed_src_rank: - refs_distributed = update_weights_from_distributed( - self._group_name, - self._model_update_groups, - self.weight_version, - self.distributed_rollout_engines, - hf_named_tensors, + # Send base model weights via update_weights_from_tensor + long_lived_tensors = [] + if base_tensors: + refs_colocated, long_lived_tensors = _send_to_colocated_engine( + base_tensors, + ipc_engine=self._ipc_engine, + ipc_gather_src=self._ipc_gather_src, + ipc_gather_group=self._ipc_gather_group, + weight_version=self.weight_version, + ) + all_refs.extend(refs_colocated) + + if self.use_distribute and self._is_distributed_src_rank: + refs_distributed = update_weights_from_distributed( + self._group_name, + self._model_update_groups, + self.weight_version, + self.distributed_rollout_engines, + base_tensors, + ) + if refs_distributed: + all_refs.extend(refs_distributed) + + # Send LoRA weights via load_lora_adapter_from_tensors + if lora_tensors and self._lora_config is not None: + refs_lora, lora_long_lived = _send_lora_to_colocated_engine( + lora_tensors, + ipc_engine=self._ipc_engine, + 
ipc_gather_src=self._ipc_gather_src, + ipc_gather_group=self._ipc_gather_group, + lora_config=self._lora_config, + lora_name=LORA_ADAPTER_NAME, ) - if refs_distributed: - all_refs.extend(refs_distributed) + all_refs.extend(refs_lora) + long_lived_tensors.extend(lora_long_lived) return all_refs, long_lived_tensors + + ############################## + ############################## + ############################## def _send_to_colocated_engine( @@ -169,7 +284,7 @@ def _send_to_colocated_engine( if dtype not in converted_named_tensors_by_dtypes: converted_named_tensors_by_dtypes[dtype] = [] converted_named_tensors_by_dtypes[dtype].append((name, tensor)) - + serialized_tensors = [] for _dtype, named_tensors in converted_named_tensors_by_dtypes.items(): flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=named_tensors) @@ -204,3 +319,78 @@ def _send_to_colocated_engine( refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs)) return refs, long_live_tensors + + + + + +############################## +###########lora############### +############################## +def _send_lora_to_colocated_engine( + lora_named_tensors: list[tuple[str, torch.Tensor]], + *, + ipc_engine, + ipc_gather_src, + ipc_gather_group, + lora_config: dict, + lora_name: str, +) -> tuple[list[ObjectRef], Any]: + """Send LoRA weights to colocated engine via load_lora_adapter_from_tensors. + + Uses FlattenedTensorBucket for cross-process serialization (same as base weights). + """ + + long_live_tensors = [] + + # Use FlattenedTensorBucket (same as base weights) for proper cross-process serialization + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=lora_named_tensors) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + long_live_tensors.append(flattened_tensor_data) + serialized_lora = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + + # Gather from all ranks in the group + serialized_lora_gathered = ( + [None] * dist.get_world_size(ipc_gather_group) if ipc_gather_src == dist.get_rank() else None + ) + dist.gather_object( + serialized_lora, + object_gather_list=serialized_lora_gathered, + dst=ipc_gather_src, + group=ipc_gather_group, + ) + + # refs = [] + # if dist.get_rank() == ipc_gather_src: + # # Send LoRA via the same mechanism as base weights, but use a special endpoint + # refs.append(ipc_engine.load_lora_adapter_from_tensors.remote( + # lora_name=lora_name, + # serialized_tensors=serialized_lora_gathered[0], # FlattenedTensorBucket format + # config_dict=lora_config, + # load_format="flattened_bucket", # Add this to indicate the format + # )) + + refs = [] + if dist.get_rank() == ipc_gather_src: + # First, unload the existing LoRA adapter (if any) + try: + ray.get(ipc_engine.unload_lora_adapter.remote(lora_name=lora_name)) + except Exception: + pass # Ignore error if adapter was not loaded + + # Then load the new LoRA weights + refs.append(ipc_engine.load_lora_adapter_from_tensors.remote( + lora_name=lora_name, + serialized_tensors=serialized_lora_gathered[0], # FlattenedTensorBucket format + config_dict=lora_config, + load_format="flattened_bucket", # Add this to indicate the format + )) + + return refs, long_live_tensors +############################## +############################## +############################## \ No newline at end of file diff --git a/miles/backends/sglang_utils/sglang_engine.py 
b/miles/backends/sglang_utils/sglang_engine.py index 9bb9b1287..c3abdd325 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -16,6 +16,51 @@ from miles.ray.ray_actor import RayActor from miles.utils.http_utils import get_host_info +############################## +###########lora############### +############################## +from argparse import Namespace + +def is_lora_enabled(args: Namespace) -> bool: + """Check if LoRA is enabled.""" + return args.lora_rank > 0 or args.lora_adapter_path is not None + + +def convert_target_modules_to_hf(megatron_modules: list[str]) -> list[str]: + """Convert Megatron LoRA target module names to HuggingFace format. + + Megatron: linear_qkv, linear_proj, linear_fc1, linear_fc2 + HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj + """ + + # If "all-linear" was converted to the standard Megatron list, just return "all" + # This allows SGLang to accept any LoRA adapter regardless of its target modules + # standard_all_linear = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + # if set(megatron_modules) == set(standard_all_linear): + # return "all" + + # This mapping should match your specific model architecture + replacements = { + "linear_qkv": ["q_proj", "k_proj", "v_proj"], + "linear_proj": ["o_proj"], + "linear_fc1": ["gate_proj", "up_proj"], + "linear_fc2": ["down_proj"], + } + + hf_modules = [] + for module in megatron_modules: + if module in replacements: + hf_modules.extend(replacements[module]) + else: + # Keep as-is if not in mapping (might already be HF format) + hf_modules.append(module) + + return hf_modules + +############################## +############################## +############################## + logger = logging.getLogger(__name__) @@ -53,8 +98,23 @@ def _to_local_gpu_id(physical_gpu_id: int) -> int: def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: from sglang.srt.entrypoints.http_server import launch_server + multiprocessing.set_start_method("spawn", force=True) server_args.host = server_args.host.strip("[]") + ############################## + ###########lora############### + ############################## + # for debugging - can be removed + # Add logging to see what args are being passed + logger.info(f"Launching SGLang server with args:") + logger.info(f" enable_lora={getattr(server_args, 'enable_lora', None)}") + logger.info(f" max_lora_rank={getattr(server_args, 'max_lora_rank', None)}") + logger.info(f" max_loras_per_batch={getattr(server_args, 'max_loras_per_batch', None)}") + logger.info(f" lora_target_modules={getattr(server_args, 'lora_target_modules', None)}") + logger.info(f" base_gpu_id={server_args.base_gpu_id}") + ############################## + ############################## + ############################## p = multiprocessing.Process(target=launch_server, args=(server_args,)) p.start() @@ -271,6 +331,40 @@ def update_weights_from_tensor( "update_weights_from_tensor", payload, ) + + ############################## + ###########lora############### + ############################## + def load_lora_adapter_from_tensors( + self, + lora_name: str, + serialized_tensors: str, + config_dict: dict, + load_format: str | None = None, # Add this parameter + pinned: bool = False, + added_tokens_config: dict | None = None, + ): + """ + Load a LoRA adapter from serialized tensor data. 
+ """ + payload = { + "lora_name": lora_name, + "serialized_tensors": serialized_tensors, + "config_dict": config_dict, + "pinned": pinned, + } + if load_format is not None: + payload["load_format"] = load_format + if added_tokens_config is not None: + payload["added_tokens_config"] = added_tokens_config + + return self._make_request( + "load_lora_adapter_from_tensors", + payload, + ) + ############################## + ############################## + ############################## def flush_cache(self): """Flush the cache of the server.""" @@ -333,9 +427,56 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self): + ############################## + ###########lora############### + ############################## + # def load_lora_adapter(self, lora_name: str, lora_path: str): + # """Load LoRA adapter from disk.""" + # return self._make_request( + # "load_lora_adapter", + # {"lora_name": lora_name, "lora_path": lora_path}, + # ) + + # def load_lora_adapter_from_tensors(self, lora_name: str, serialized_tensors: str, config_dict: dict): + # """Load LoRA adapter from serialized tensors.""" + # return self._make_request( + # "load_lora_adapter_from_tensors", + # {"lora_name": lora_name, "serialized_tensors": serialized_tensors, "config_dict": config_dict}, + # ) + + def unload_lora_adapter(self, lora_name: str): + """Unload LoRA adapter.""" + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + ############################## + ############################## + ############################## + + + ############################## + ###########lora############### + ############################## + # def release_memory_occupation(self): + # self.flush_cache() + # return self._make_request("release_memory_occupation") + + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage release: weights, kv_cache + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) + ############################## + ############################## + ############################## + + + def resume_memory_occupation(self, tags: list[str] = None): """ @@ -494,6 +635,60 @@ def _compute_server_args( kwargs["dtype"] = "float16" external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] + ############################## + ###########lora############### + ############################## + # if is_lora_enabled(args): + # kwargs["enable_lora"] = True + # kwargs["max_lora_rank"] = args.lora_rank + # kwargs["max_loras_per_batch"] = 1 + # # NOTE: lora_target_modules might not be supported by your SGLang version + # # Comment out this line if SGLang doesn't support it: + # # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + # # Log for debugging + # kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + # # kwargs["lora_target_modules_list"] = convert_target_modules_to_hf(args.target_modules) + # # print(1111111111) + # # print(kwargs["lora_target_modules"]) + # # print(2222222222) + # # exit() + + # if is_lora_enabled(args): + if args.lora_rank > 0 or args.lora_adapter_path is not None: + kwargs["enable_lora"] = True + kwargs["max_loras_per_batch"] = 1 #!!!!!!!! 
+ # kwargs["max_lora_rank"] = args.lora_rank + # Ensure a valid positive LoRA rank is passed to the SGLang engine. + # If LoRA is enabled via adapter path but lora_rank is not set to a + # positive value, default to rank 1 to avoid an invalid configuration. + if getattr(args, "lora_rank", None) and args.lora_rank > 0: + max_lora_rank = args.lora_rank + else: + max_lora_rank = 1 + kwargs["max_lora_rank"] = max_lora_rank + # kwargs["lora_target_modules"] = args.target_modules + kwargs["lora_target_modules"] = convert_target_modules_to_hf(args.target_modules) + + ##### For rollout debug mode to add: + if args.debug_rollout_only: + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} + kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} + + if args.lora_adapter_path is None: + # raise ValueError("lora_adapter_path must be provided") + # raise ValueError("lora_adapter_path must be provided") + # pass + logger.info("Did not provide pre-trained LoRA adapter_path, will use random initial weights") + else: + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # SGLang lora_paths Format: {"adapter_name": "path_to_adapter"} + kwargs["lora_paths"] = {LORA_ADAPTER_NAME: args.lora_adapter_path} + + ############################## + ############################## + ############################## + unused_keys = set(kwargs.keys()) for attr in dataclasses.fields(ServerArgs): if worker_type == "decode" and attr.name == "enable_hierarchical_cache": @@ -502,6 +697,7 @@ def _compute_server_args( kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") unused_keys.discard(attr.name) + # for compatibility with old args if len(unused_keys) > 0: logger.info(f"Warning: The following arguments is not supported in the current sglang: {unused_keys}.") diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 4b22c5ddc..08c886516 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -130,8 +130,33 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + ############################## + ###########lora############### + ############################## + # def offload(self): + # return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + + def offload(self, tags: list[str] | None = None): + self.health_monitoring_pause() + return ray.get( + [ + engine.release_memory_occupation.remote(tags=tags) + for engine in self.rollout_engines + if engine is not None + ] + ) + + def health_monitoring_pause(self): + if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): + self._health_monitor.stop() + + def health_monitoring_resume(self): + if self.args.use_fault_tolerance and hasattr(self, '_health_monitor'): + self._health_monitor.start() + ############################## + ############################## + ############################## + def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) @@ -459,6 +484,23 @@ def init_rollout_engines(args, pg, all_rollout_engines): init_handles = [engine.init.remote(**(addr_and_ports[rank])) for rank, engine in rollout_engines] ray.get(init_handles) + # ############################## + # ###########lora############### + # ############################## + # # Load LoRA 
adapter from disk in debug-rollout-only mode + # if args.debug_rollout_only and args.lora_adapter_path: + # from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + # logger.info(f"Loading LoRA adapter from {args.lora_adapter_path} for debug-rollout-only mode") + # for i, engine in rollout_engines: + # ray.get(engine.load_lora_adapter.remote( + # lora_name=LORA_ADAPTER_NAME, + # lora_path=args.lora_adapter_path + # )) + # logger.info("LoRA adapter loaded successfully") + # ############################## + # ############################## + # ############################## + return num_new_engines diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 77c540d60..74c4bfd1d 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -25,6 +25,14 @@ from .rm_hub import async_rm, batched_async_rm +############################## +###########lora############### +############################## +from miles.backends.sglang_utils.sglang_engine import is_lora_enabled +############################## +############################## +############################## + __all__ = ["generate_rollout"] logger = logging.getLogger(__name__) @@ -136,6 +144,16 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + ############################## + ###########lora############### + ############################## + if is_lora_enabled(args): + from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME + payload["lora_path"] = LORA_ADAPTER_NAME + ############################## + ############################## + ############################## + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 595542c58..2b1822304 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,6 +104,24 @@ def add_cluster_arguments(parser): ), ) + ############################## + ###########lora############### + ############################## + parser.add_argument( + "--offload-rollout-level", + type=str, + nargs="+", + default=["kv_cache", "weight"], + help=( + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" + ), + ) + ############################## + ############################## + ############################## + reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -911,6 +929,74 @@ def add_algo_arguments(parser): help="The threshold for Off-Policy Sequence Masking (OPSM).", ) return parser + + ############################## + ###########lora############### + ############################## + def add_lora_arguments(parser): + """Add LoRA-related arguments for Megatron backend.""" + parser.add_argument( + "--lora-rank", + type=int, + default=0, + help="LoRA rank. Set to 0 to disable LoRA (default: 0)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha for scaling (default: 16)", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout rate (default: 0.0)", + ) + parser.add_argument( + "--target-modules", + type=str, + default=None, + help="Target modules for LoRA. 
Use 'all-linear' or comma-separated module names " + "(e.g., 'q_proj,k_proj,v_proj,o_proj' for HF naming or 'linear_qkv,linear_proj' for Megatron naming)", + ) + parser.add_argument( + "--exclude-modules", + type=str, + default=None, + help="Modules to exclude from LoRA (comma-separated)", + ) + parser.add_argument( + "--lora-adapter-path", + type=str, + default=None, + help="Path to load pre-trained LoRA adapter weights (default: None)", + ) + parser.add_argument( + "--lora-sync-from-tensor", + action="store_true", + default=False, + help="Sync LoRA weights via tensor instead of file (more efficient)", + ) + # parser.add_argument( + # "--share-ref-base-model", + # action="store_true", + # default=False, + # help="Share base model between actor and reference model (saves memory for LoRA)", + # ) + + # parser.add_argument( + # "--no-use-distributed-optimizer", + # action="store_false", + # default=True, + # dest="Use distributed optimizer (ZeRO)", + # help="Use distributed optimizer (ZeRO). Disable for LoRA training. (default: True)", + # ) + + return parser + ############################## + ############################## + ############################## def add_router_arguments(parser): parser.add_argument( @@ -1352,6 +1438,13 @@ def add_sglang_tp_size(): parser = add_data_arguments(parser) parser = add_eval_arguments(parser) parser = add_algo_arguments(parser) + ############################## + ###########lora############### + ############################## + parser = add_lora_arguments(parser) + ############################## + ############################## + ############################## parser = add_wandb_arguments(parser) parser = add_tensorboard_arguments(parser) parser = add_router_arguments(parser) @@ -1516,6 +1609,34 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + ############################## + ###########lora############### + ############################## + if args.lora_rank > 0: + # assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
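        # NOTE (worked example with hypothetical flags): the normalization below turns the CLI string
        # into a list of HF module names with exclusions removed, e.g.
        #     --target-modules all-linear --exclude-modules gate_proj,up_proj
        #     => args.target_modules == ["q_proj", "k_proj", "v_proj", "o_proj", "down_proj"]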
+ + # (to-do) yusheng: hf->mg; mg->hf + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + ############################## + ############################## + ############################## + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/train.py b/train.py index a4f6824cc..c41687ba8 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,16 @@ from miles.utils.misc import should_run_periodic_action from miles.utils.tracking_utils import init_tracking +############################## +###########lora############### +############################## +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +############################## +############################## +############################## + + + def train(args): configure_logger() @@ -26,12 +36,32 @@ def train(args): # create the actor and critic models actor_model, critic_model = create_training_models(args, pgs, rollout_manager) + + # ############################## + # ###########lora############### + # ############################## + # ========== LoRA check ========== + # # check LoRA status + # from miles.backends.megatron_utils.lora_utils import is_lora_enabled + # if is_lora_enabled(args) or True: + # lora_status = ray.get(actor_model._actor_handlers[0].check_lora_status.remote()) + # print(f"LoRA modules: {len(lora_status['lora_modules'])}") + # assert len(lora_status['lora_modules']) > 0, "LoRA modules not found!" + # # already cannot access weight here - before torch_memory_saver + # # ========== LoRA check end ========== + # ############################## + # ###########lora############### + # ############################## + + if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # always update weight first so that sglang has the loaded weights from training. actor_model.update_weights() + if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) @@ -55,10 +85,20 @@ def offload_train(): else: actor_model.clear_memory() + ############################## + ###########lora############### + ############################## def onload_rollout(): if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + # def onload_rollout(): + # if args.offload_rollout and "weight" in args.offload_rollout_level: + # ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + ############################## + ############################## + ############################## + # train loop. # note that for async training, one can change the position of the sync operation(ray.get). 
    for rollout_id in range(args.start_rollout_id, args.num_rollout):
@@ -67,8 +107,23 @@ def onload_rollout():
        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
+        ##############################
+        ###########lora###############
+        ##############################
+        # if args.offload_rollout:
+        #     ray.get(rollout_manager.offload.remote())
+
        if args.offload_rollout:
-            ray.get(rollout_manager.offload.remote())
+            offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH]
+            if "kv_cache" in args.offload_rollout_level:
+                offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE)
+            if "weight" in args.offload_rollout_level:
+                offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS)
+            ray.get(rollout_manager.offload.remote(tags=offload_tags))
+
+        ##############################
+        ##############################
+        ##############################

        if args.use_critic:
            critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
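# NOTE (illustrative sketch): the offload branch above maps the --offload-rollout-level values onto
# SGLang memory tags, with CUDA graphs always released. The helper below only restates that mapping
# so it can be read in isolation; the function name rollout_offload_tags is hypothetical, while the
# constants are the ones this patch imports in train.py.

from sglang.srt.constants import (
    GPU_MEMORY_TYPE_CUDA_GRAPH,
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_WEIGHTS,
)


def rollout_offload_tags(offload_rollout_level: list[str]) -> list[str]:
    """Return the SGLang memory tags to release for the configured offload level."""
    tags = [GPU_MEMORY_TYPE_CUDA_GRAPH]
    if "kv_cache" in offload_rollout_level:
        tags.append(GPU_MEMORY_TYPE_KV_CACHE)
    if "weight" in offload_rollout_level:
        tags.append(GPU_MEMORY_TYPE_WEIGHTS)
    return tags


# With the argument default (["kv_cache", "weight"]) this releases CUDA graphs, KV cache and weights;
# passing only ["kv_cache"] keeps the weights resident so the next onload skips the weight reload.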