Conversation

@yushengsu-thu (Collaborator) commented Jan 7, 2026

Description

  1. Megatron backend support for LoRA
  2. Disk-based weight sync
  3. Update LoRA weights via tensor
    (Pushes LoRA weights to the SGLang rollout engine as tensors, which is faster than the previous disk-sync approach; a conceptual sketch follows this list.)
    Blocked on this SGLang PR being merged: Update LoRA Weights via Tensor sgl-project/sglang#16226
  4. [to-do] Refactor and fix bugs
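A rough Python sketch of the difference between the two sync paths. The SGLang-side method names below are hypothetical stand-ins, since the exact API is defined by the still-unmerged sgl-project/sglang#16226; only the shape of the idea is intended.

# Hedged sketch: disk sync vs. tensor sync for LoRA weights.
# `load_lora_adapter` / `update_lora_from_tensor` are ILLUSTRATIVE names,
# not confirmed SGLang APIs.
import torch

def sync_via_disk(lora_state: dict, path: str, engine):
    torch.save(lora_state, path)    # serialize adapter weights to disk
    engine.load_lora_adapter(path)  # engine re-reads them from disk (slow)

def sync_via_tensor(lora_state: dict, engine):
    # Ship tensors directly (shared memory / IPC), skipping the disk round-trip.
    engine.update_lora_from_tensor(list(lora_state.items()))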

Changes Made

Megatron-Bridge fork: https://github.com/yushengsu-thu/Megatron-Bridge/tree/merged-megatron-0.16.0rc0
Codebase changes in miles (this PR)

Prerequisites

docker run --rm -it \
  --gpus all \
  -p 8264:8264 \
  --cap-add SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v /.ssh/:/.ssh/ \
  -v /data:/data \
  --shm-size 128G \
  --name miles_yusheng \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  -w $PWD \
  radixark/miles:latest

  • Megatron-Bridge

git clone --branch merged-megatron-0.16.0rc0 --single-branch https://github.com/yushengsu-thu/Megatron-Bridge.git
cd Megatron-Bridge
pip install -e . --no-deps --no-build-isolation
pip install megatron-energon --no-deps
pip install multi-storage-client --no-deps
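After the installs, a quick import check can confirm the Megatron-Bridge PEFT modules this PR relies on resolve (module paths taken from the code under review; the check itself is just a suggestion):

# Sanity check: the LoRA pieces imported by this PR should resolve.
from megatron.bridge.peft.lora import LoRA
from megatron.bridge.peft.adapter_wrapper import AdapterWrapper

print("Megatron-Bridge PEFT imports OK:", LoRA.__name__, AdapterWrapper.__name__)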

Testing

# Dataset and model download
huggingface-cli download --repo-type dataset zhuzilin/gsm8k --local-dir /root/gsm8k
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/Qwen2.5-0.5B-Instruct

# Codebase
git clone --branch miles-lora-megatron --single-branch https://github.com/yushengsu-thu/miles.git 
cd miles
source scripts/models/qwen2.5-0.5B.sh
PYTHONPATH=/root/Megatron-LM/ python \
   tools/convert_hf_to_torch_dist.py \
   ${MODEL_ARGS[@]} \
   --hf-checkpoint /root/Qwen2.5-0.5B-Instruct \
   --save /root/Qwen2.5-0.5B-Instruct_torch_dist/

# Run script:
bash examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh

Related Issues, PRs:

LoRA FSDP backend PR: #377
SGLang sync from tensor: sgl-project/sglang#16226

Code Style Compliance

  • [ ] Performance: Minimized synchronization calls (.item(), .cpu(), .tolist()) in inference paths
  • [ ] Architecture: No duplicate code > 5 lines; files < 2,000 lines
  • [ ] Function Purity: Avoided in-place modification of input arguments (unless explicitly documented for memory optimization)
  • [ ] Pythonic: Lean constructors, minimal dynamic attributes, proper type hints on public APIs
  • [ ] Testing: Provided a test script that reviewers can copy & paste to run immediately

@yushengsu-thu requested a review from fzyzcjy as a code owner January 7, 2026 00:41
Copilot AI review requested due to automatic review settings January 7, 2026 00:41
@gemini-code-assist (Contributor)

Summary of Changes

Hello @yushengsu-thu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the miles framework by integrating LoRA capabilities into its Megatron backend. The primary goal is to enable more memory-efficient and faster fine-tuning of large language models by only updating a small set of LoRA parameters, rather than the entire model. This includes new utilities for applying LoRA, managing its checkpoints, and ensuring seamless weight synchronization with SGLang rollout engines, ultimately making the training process more accessible and resource-friendly.

Highlights

  • LoRA Integration: Introduced comprehensive LoRA (Low-Rank Adaptation) support for the Megatron backend, enabling efficient fine-tuning of large language models.
  • Memory Optimization: Implemented a mechanism to share base model weights between actor and reference models when using LoRA, significantly reducing memory footprint during training (see the sketch after this list).
  • LoRA-Specific Checkpointing: Added functionality to save and load only the LoRA adapter weights, allowing for smaller checkpoints and faster iteration.
  • SGLang Integration: Updated SGLang utilities to support dynamic loading, unloading, and synchronization of LoRA adapters, facilitating their use in rollout engines.
  • New Example Script: Provided a new example script (run-qwen2.5-0.5B-gsm8k-lora.sh) demonstrating how to train a Qwen2.5-0.5B model with LoRA on the GSM8K dataset.
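
A minimal sketch of that base-weight sharing idea, under the assumption that only adapter parameters differ between actor and reference (the helper name is illustrative; the ".adapter."/"lora_" markers match the filtering used elsewhere in this PR):

# Hedged sketch: alias frozen base weights between actor and reference models.
def share_base_weights(actor, ref):
    actor_params = dict(actor.named_parameters())
    for name, ref_param in ref.named_parameters():
        if ".adapter." in name or "lora_" in name:
            continue  # adapters stay per-model
        base = actor_params[name]
        base.requires_grad_(False)    # the base is frozen under LoRA
        ref_param.data = base.data    # alias storage instead of copying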
Ignored Files
  • Ignored by pattern: .github/workflows/** (1)
    • .github/workflows/pr-test.yml

@gemini-code-assist bot left a comment:

Code Review

This pull request introduces LoRA (Low-Rank Adaptation) support for the Megatron backend, which is a significant feature enhancement. The changes cover argument parsing, model initialization, weight update mechanisms, and checkpointing, along with a new example script for running LoRA experiments. While the implementation is comprehensive, I've identified several issues that need attention. There's a critical bug in the new run script where LoRA arguments are not being passed, preventing LoRA training. Additionally, there are high-severity concerns regarding inconsistent and potentially incorrect LoRA parameter name conversions, which could lead to runtime errors and maintainability problems. I've also noted some medium-severity issues related to code style and fragile implementation patterns. Addressing these points will improve the correctness and robustness of this new feature.

Comment on lines 147 to 156
${MODEL_ARGS[@]} \
${CKPT_ARGS[@]} \
${ROLLOUT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${GRPO_ARGS[@]} \
${WANDB_ARGS[@]} \
${PERF_ARGS[@]} \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]}

critical

The LORA_ARGS array is defined but not passed to the train.py script in the ray job submit command. This is a critical issue, as the script is intended for LoRA training, but the LoRA configuration will not be applied. You should add ${LORA_ARGS[@]} to the list of arguments passed to train.py.

Suggested change (add ${LORA_ARGS[@]}):

${MODEL_ARGS[@]} \
${CKPT_ARGS[@]} \
${LORA_ARGS[@]} \
${ROLLOUT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${GRPO_ARGS[@]} \
${WANDB_ARGS[@]} \
${PERF_ARGS[@]} \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]}

for m in unwrapped:
    for name, param in m.named_parameters():
        if ".adapter." in name or "lora_" in name:
            clean_name = name.replace(".to_wrap.", "")

high

The get_lora_state_dict function in lora_utils.py cleans parameter names using name.replace(".to_wrap.", "."). However, here in load_lora_checkpoint, the name is cleaned with name.replace(".to_wrap.", ""), which is missing a dot. This inconsistency will cause a mismatch in parameter names and prevent LoRA checkpoints from being loaded correctly.

Suggested change
- clean_name = name.replace(".to_wrap.", "")
+ clean_name = name.replace(".to_wrap.", ".")
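
A two-line check makes the mismatch concrete (the parameter name below is hypothetical but follows the naming convention shown in this PR):

# Demonstrates the save/load key mismatch caused by the inconsistent replacements.
name = "decoder.layers.0.self_attention.linear_qkv.to_wrap.weight"
save_key = name.replace(".to_wrap.", ".")  # key written by get_lora_state_dict
load_key = name.replace(".to_wrap.", "")   # key computed in load_lora_checkpoint
print(save_key)  # ...linear_qkv.weight
print(load_key)  # ...linear_qkvweight  <- dots swallowed, lookup fails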

Comment on lines 193 to 211
# Map Megatron module names to HF names
megatron_to_hf = {
    "self_attention.linear_qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
    "self_attention.linear_proj": ["self_attn.o_proj"],
    "mlp.linear_fc1": ["mlp.gate_proj", "mlp.up_proj"],
    "mlp.linear_fc2": ["mlp.down_proj"],
}

lora_suffix = "lora_A.weight" if lora_type == "in" else "lora_B.weight"

hf_modules = megatron_to_hf.get(module_path)
if hf_modules is None:
    return None

# For QKV, we typically use the first mapping (q_proj)
# In practice, Megatron-Bridge handles this internally
hf_module = hf_modules[0]

return f"base_model.model.model.layers.{layer_idx}.{hf_module}.{lora_suffix}"

high

The function _megatron_to_hf_lora_name maps LoRA weights for a fused linear_qkv layer in Megatron to only the q_proj layer in the HuggingFace format. This is likely incorrect if the HF model has separate Q, K, and V projection layers, as the LoRA weights for a fused layer cannot typically be reused for just one of the component projections. This could lead to incorrect behavior or errors in SGLang. The logic should be revisited to correctly handle the mapping for fused layers, possibly by splitting the LoRA weights or creating mappings for k_proj and v_proj as well.
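
One way to handle the fused case correctly, sketched under the assumption of a contiguous [q; k; v] row layout (Megatron's actual per-head-group interleaving would need a permutation first; the function name and dims are illustrative): lora_A can be shared by all three projections, while lora_B must be split along its output rows.

import torch

def split_fused_qkv_lora(lora_a, lora_b, q_dim, k_dim, v_dim):
    """Split a fused linear_qkv LoRA into per-projection HF-style adapters.

    Assumes lora_b rows are laid out contiguously as [q; k; v]; Megatron's
    interleaved per-head-group QKV layout would require reordering first.
    """
    q_b, k_b, v_b = torch.split(lora_b, [q_dim, k_dim, v_dim], dim=0)
    # lora_A is shared: the same low-rank input projection feeds q, k, and v.
    return {
        "q_proj": (lora_a, q_b),
        "k_proj": (lora_a, k_b),
        "v_proj": (lora_a, v_b),
    }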

Comment on lines 238 to 283
def named_params_and_buffers_lora_only(args, model):
    """Only yield LoRA adapter parameters for weight sync."""
    from miles.backends.megatron_utils.lora_utils import is_lora_model

    if not is_lora_model(model):
        yield from named_params_and_buffers(args, model)
        return

    for name, param in model.named_parameters():
        if ".adapter." in name or "lora_" in name:
            # Convert to HF-compatible name
            hf_name = _convert_lora_name_to_hf(name)
            yield (hf_name, param)


def _convert_lora_name_to_hf(megatron_name: str) -> str:
    """
    Convert Megatron LoRA param name to HF-compatible format.
    Examples:
        module.module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight
        -> base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    """
    import re

    # Clean up name
    name = megatron_name.replace("module.module.", "").replace("module.", "")

    # Handle LoRA adapter names
    if ".adapter.linear_in.weight" in name:
        base = name.replace(".adapter.linear_in.weight", "")
        lora_suffix = ".lora_A.weight"
    elif ".adapter.linear_out.weight" in name:
        base = name.replace(".adapter.linear_out.weight", "")
        lora_suffix = ".lora_B.weight"
    else:
        return megatron_name  # Not a LoRA param

    # Map layer structure
    base = base.replace("decoder.layers", "model.layers")
    base = base.replace("self_attention.linear_qkv", "self_attn.q_proj")  # Simplified
    base = base.replace("self_attention.linear_proj", "self_attn.o_proj")
    base = base.replace("mlp.linear_fc1", "mlp.gate_proj")
    base = base.replace("mlp.linear_fc2", "mlp.down_proj")

    return f"base_model.model.{base}{lora_suffix}"

high

This file introduces _convert_lora_name_to_hf, which duplicates the functionality of _megatron_to_hf_lora_name in lora_utils.py. The implementations are different (string replacement vs. regex) and produce inconsistent output prefixes (base_model.model. vs. base_model.model.model.). This duplication is a significant maintainability risk and a likely source of bugs. This logic should be consolidated into a single, robust function in lora_utils.py and reused here.
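
A consolidated sketch, assuming lora_utils.py becomes the single owner of the mapping (the function name and the chosen prefix are illustrative; the fused-QKV caveat from the previous comment still applies):

import re
from typing import Optional

# Single source of truth for the Megatron -> HF module mapping.
_MEGATRON_TO_HF = {
    "self_attention.linear_qkv": "self_attn.q_proj",  # fused-QKV caveat applies
    "self_attention.linear_proj": "self_attn.o_proj",
    "mlp.linear_fc1": "mlp.gate_proj",
    "mlp.linear_fc2": "mlp.down_proj",
}

def megatron_lora_name_to_hf(name: str) -> Optional[str]:
    """Convert a Megatron LoRA param name to HF/PEFT format; None if not LoRA."""
    name = re.sub(r"^(module\.)+", "", name)
    if name.endswith(".adapter.linear_in.weight"):
        base, suffix = name[: -len(".adapter.linear_in.weight")], "lora_A.weight"
    elif name.endswith(".adapter.linear_out.weight"):
        base, suffix = name[: -len(".adapter.linear_out.weight")], "lora_B.weight"
    else:
        return None
    base = base.replace("decoder.layers", "model.layers")
    for mg, hf in _MEGATRON_TO_HF.items():
        base = base.replace(mg, hf)
    return f"base_model.model.{base}.{suffix}"  # pick one prefix, reuse it everywhere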

Comment on lines 10 to 11
pkill -9 ray
pkill -9 python

medium

The pkill commands on lines 10 and 11 are duplicates of the commands on lines 7 and 8. The sleep on line 9 and the duplicated pkill commands should be removed to make the script cleaner and avoid redundant operations.

Comment on lines 626 to 635
old_args = (self.args.load, self.args.no_load_optim,
            self.args.no_load_rng, self.args.finetune)
self.args.load = args.ref_load
self.args.no_load_optim = True
self.args.no_load_rng = True
self.args.finetune = True

load_checkpoint(self.model, None, None, {}, skip_load_to_model_and_opt=False)
(self.args.load, self.args.no_load_optim,
 self.args.no_load_rng, self.args.finetune) = old_args

medium

Temporarily modifying self.args to call load_checkpoint is a fragile pattern. If an exception occurs between the modification and restoration, self.args could be left in an inconsistent state. While it's handled correctly here with a tuple for restoration, this pattern can be hard to maintain. Consider wrapping this logic in a try...finally block to guarantee restoration, or if possible, refactor load_checkpoint to accept a temporary configuration object.
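
A hedged sketch of the try/finally packaging, assuming a plain attribute namespace (the helper name is illustrative):

from contextlib import contextmanager

@contextmanager
def override_args(args, **overrides):
    """Temporarily set attributes on args; always restore them, even on error."""
    saved = {k: getattr(args, k) for k in overrides}
    try:
        for k, v in overrides.items():
            setattr(args, k, v)
        yield args
    finally:
        for k, v in saved.items():
            setattr(args, k, v)

# Call-site usage corresponding to the snippet above:
# with override_args(self.args, load=args.ref_load, no_load_optim=True,
#                    no_load_rng=True, finetune=True):
#     load_checkpoint(self.model, None, None, {}, skip_load_to_model_and_opt=False)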

Comment on lines 83 to 84
import os
from pathlib import Path

medium

The os and pathlib.Path modules are already imported at the top of the file. These redundant local imports should be removed to improve code clarity.

"peft_type": "LORA",
"r": args.lora_rank,
"lora_alpha": args.lora_alpha,
"target_modules": args.target_modules,

medium

The target_modules are saved as a single string, which might be comma-separated. For compatibility with the Hugging Face PEFT format, it's better to save this as a list of strings. Please consider parsing the string into a list before saving it to the JSON file, similar to how it's handled in apply_lora_to_model.
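
A minimal sketch of that normalization, mirroring the parsing already shown in apply_lora_to_model (helper names are illustrative):

def normalize_target_modules(raw):
    """Turn the comma-separated CLI string into a PEFT-style list."""
    if raw == "all-linear":
        return ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    return [m.strip() for m in raw.split(",")]

def build_adapter_config(args):
    """Build a PEFT-compatible adapter_config dict from parsed CLI args."""
    return {
        "peft_type": "LORA",
        "r": args.lora_rank,
        "lora_alpha": args.lora_alpha,
        "target_modules": normalize_target_modules(args.target_modules),  # list, not string
    }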

Comment on lines 116 to 117
import os
from pathlib import Path

medium

The os and pathlib.Path modules are already imported at the top of the file. These redundant local imports should be removed.

Comment on lines 158 to 171
# def _megatron_to_hf_lora_name(megatron_name: str) -> Optional[str]:
#     """Convert Megatron LoRA parameter name to HF format."""
#     # Example mapping (adjust based on actual model)
#     # decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight
#     # -> model.layers.0.self_attn.q_proj.lora_A.weight

#     if ".adapter.linear_in.weight" in megatron_name:
#         base_name = megatron_name.replace(".adapter.linear_in.weight", "")
#         return base_name.replace("linear_qkv", "q_proj") + ".lora_A.weight"
#     elif ".adapter.linear_out.weight" in megatron_name:
#         base_name = megatron_name.replace(".adapter.linear_out.weight", "")
#         return base_name.replace("linear_qkv", "q_proj") + ".lora_B.weight"

#     return None

medium

This block of commented-out code appears to be a previous implementation of _megatron_to_hf_lora_name. It should be removed to improve code clarity and reduce clutter.

Copilot AI left a comment:

Pull request overview

This PR introduces LoRA (Low-Rank Adaptation) support for the Megatron backend in the MILES training framework, enabling parameter-efficient fine-tuning of large language models.

Key Changes:

  • Adds LoRA configuration arguments and utilities for applying LoRA adapters to Megatron models
  • Implements LoRA-specific weight synchronization between training and rollout engines using tensor-based updates
  • Provides checkpoint management for LoRA adapters with support for shared base model weights

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 21 comments.

Summary per file:

miles/utils/arguments.py — Adds command-line arguments for LoRA configuration (rank, alpha, target modules, dropout, adapter path)
miles/rollout/sglang_rollout.py — Includes LoRA adapter name in rollout payload when LoRA is enabled
miles/backends/sglang_utils/sglang_engine.py — Adds LoRA adapter loading/unloading methods and configures SGLang server with LoRA parameters
miles/backends/megatron_utils/update_weight/update_weight_from_tensor_lora.py — Implements LoRA-specific weight updater that syncs only LoRA parameters while optionally sharing the base model
miles/backends/megatron_utils/update_weight/common.py — Adds utilities for extracting and converting LoRA parameter names between Megatron and HuggingFace formats
miles/backends/megatron_utils/model.py — Applies LoRA adapters to the model during initialization when a LoRA rank is specified
miles/backends/megatron_utils/lora_utils.py — Core LoRA utilities including adapter application, weight extraction, and name conversion between Megatron and HF formats
miles/backends/megatron_utils/checkpoint.py — Implements LoRA adapter checkpoint save/load functionality separate from base model checkpoints
miles/backends/megatron_utils/actor.py — Integrates the LoRA weight updater and adds support for a shared reference model with different LoRA adapters
examples/reproducibility/run-qwen2.5-0.5B-gsm8k-lora.sh — Example training script demonstrating LoRA configuration for GSM8K fine-tuning
.github/workflows/pr-test.yml — Updates CI workflow installation step (contains duplicate run command)


Comment on lines 44 to 49
if args.target_modules == "all-linear":
target_modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
elif "," in args.target_modules:
target_modules = [m.strip() for m in args.target_modules.split(",")]
else:
target_modules = [args.target_modules]
Copilot AI commented Jan 7, 2026:

The function checks args.target_modules but this could be None based on the argument definition. The code doesn't handle the None case before attempting string operations. This should either provide a default value or add proper None checking before using the value.

Suggested change:

target_modules_arg = getattr(args, "target_modules", None)
# Default to standard Megatron linear modules if not specified
if not target_modules_arg:
    target_modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
elif target_modules_arg == "all-linear":
    target_modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
elif "," in target_modules_arg:
    target_modules = [m.strip() for m in target_modules_arg.split(",")]
else:
    target_modules = [target_modules_arg]

Comment on lines 519 to 523
if args.lora_rank > 0 or args.lora_adapter_path is not None:
    kwargs["enable_lora"] = True
    kwargs["max_lora_rank"] = args.lora_rank
    kwargs["max_loras_per_batch"] = 1
    kwargs["lora_target_modules"] = args.target_modules
Copilot AI commented Jan 7, 2026:

The condition checks args.lora_rank > 0 or args.lora_adapter_path is not None, but lora_adapter_path is not validated or used in the configuration. This could lead to enabling LoRA without proper configuration when only a path is provided but no rank is set. Consider whether both conditions are necessary or if the logic needs adjustment.

    return state_dict


def get_lora_weights_for_sglang(model: nn.Module, args) -> tuple[dict, dict]:
Copilot AI commented Jan 7, 2026:

The function get_lora_weights_for_sglang has a return type hint tuple[dict, dict] but uses Python 3.9+ syntax. For better compatibility, consider using Tuple[dict, dict] from typing module which is already imported at the top of the file.

Comment on lines 147 to 156
${MODEL_ARGS[@]} \
${CKPT_ARGS[@]} \
${ROLLOUT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${GRPO_ARGS[@]} \
${WANDB_ARGS[@]} \
${PERF_ARGS[@]} \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]}
Copilot AI commented Jan 7, 2026:

The LORA_ARGS array is defined but never used in the ray job submission command. The script defines LORA_ARGS but doesn't include ${LORA_ARGS[@]} in the command invocation at the bottom, which means the LoRA configuration won't be passed to the training script.

Comment on lines 191 to 201
layer_idx, module_path, lora_type = match.groups()

# Map Megatron module names to HF names
megatron_to_hf = {
    "self_attention.linear_qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
    "self_attention.linear_proj": ["self_attn.o_proj"],
    "mlp.linear_fc1": ["mlp.gate_proj", "mlp.up_proj"],
    "mlp.linear_fc2": ["mlp.down_proj"],
}

lora_suffix = "lora_A.weight" if lora_type == "in" else "lora_B.weight"
Copilot AI commented Jan 7, 2026:

Inconsistent naming: the variable is named lora_type but contains values "in" and "out" which represent direction, not type. Consider renaming to lora_direction or adapter_direction for clarity.

# Import from Megatron-Bridge
from megatron.bridge.peft.lora import LoRA
from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
from megatron.bridge.peft.utils import ParallelLinearAdapter
Copilot AI commented Jan 7, 2026:

Import of 'ParallelLinearAdapter' is not used.

Suggested change (remove):
from megatron.bridge.peft.utils import ParallelLinearAdapter

from megatron.bridge.peft.lora import LoRA
from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
from megatron.bridge.peft.utils import ParallelLinearAdapter
from megatron.bridge.models.conversion.peft_bridge import MegatronPeftBridge
Copilot AI commented Jan 7, 2026:

Import of 'MegatronPeftBridge' is not used.

Suggested change (remove):
from megatron.bridge.models.conversion.peft_bridge import MegatronPeftBridge

@@ -0,0 +1,75 @@
from argparse import Namespace
Copilot AI commented Jan 7, 2026:

Import of 'Namespace' is not used.

Suggested change (remove):
from argparse import Namespace

@@ -0,0 +1,75 @@
from argparse import Namespace
from collections.abc import Sequence
Copilot AI commented Jan 7, 2026:

Import of 'Sequence' is not used.

Suggested change (remove):
from collections.abc import Sequence

    module.module.decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight
    -> base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    """
    import re
Copilot AI commented Jan 7, 2026:

This import of module re is redundant, as it was previously imported on line 2.

Suggested change (remove):
import re

@yushengsu-thu changed the title from "[feat] miles lora megatron backend" to "[WIP] [feat] miles lora megatron backend" Jan 7, 2026
@yushengsu-thu marked this pull request as draft January 7, 2026 00:57