diff --git a/docs/models/vlm/index.md b/docs/models/vlm/index.md index 8c030258f2..b7407793b1 100644 --- a/docs/models/vlm/index.md +++ b/docs/models/vlm/index.md @@ -11,4 +11,5 @@ ministral3.md nemotron-nano-v2-vl.md qwen2.5-vl.md qwen3-vl.md +qwen35-vl.md ``` diff --git a/docs/models/vlm/qwen35-vl.md b/docs/models/vlm/qwen35-vl.md new file mode 100644 index 0000000000..ba01d3384b --- /dev/null +++ b/docs/models/vlm/qwen35-vl.md @@ -0,0 +1,62 @@ +# Qwen 3.5 + +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a family of vision-language models supporting multimodal understanding across text, images, and videos. Qwen3.5-VL includes both dense models and Mixture-of-Experts (MoE) variants for improved efficiency at scale. + +Qwen 3.5 models feature a hybrid architecture combining GDN (Gated DeltaNet) layers with standard attention layers, SwiGLU activations, and RMSNorm. MoE variants use top-k routing with shared experts for better quality. + +Qwen 3.5 models are supported via Megatron Bridge with auto-detected configuration and weight mapping. + +```{important} +Please upgrade to `transformers` >= 5.2.0 in order to use the Qwen 3.5 models. 
+``` + +## Available Models + +### Dense Models +- **Qwen3.5 0.8B** (`Qwen/Qwen3.5-0.8B`): 0.8B parameter vision-language model + - Recommended: 1 node, 8 GPUs + +- **Qwen3.5 2B** (`Qwen/Qwen3.5-2B`): 2B parameter vision-language model + - Recommended: 1 node, 8 GPUs + +- **Qwen3.5 4B** (`Qwen/Qwen3.5-4B`): 4B parameter vision-language model + - Recommended: 1 node, 8 GPUs + +- **Qwen3.5 9B** (`Qwen/Qwen3.5-9B`): 9B parameter vision-language model + - Recommended: 1 node, 8 GPUs + +- **Qwen3.5 27B** (`Qwen/Qwen3.5-27B`): 27B parameter vision-language model + - Recommended: 2 nodes, 16 GPUs + +### Mixture-of-Experts (MoE) Models +- **Qwen3.5 35B-A3B** (`Qwen/Qwen3.5-35B-A3B`): 35B total parameters, 3B activated per token + - Recommended: 2 nodes, 16 GPUs + +- **Qwen3.5 122B-A10B** (`Qwen/Qwen3.5-122B-A10B`): 122B total parameters, 10B activated per token + - Recommended: 4 nodes, 32 GPUs + +- **Qwen3.5 397B-A17B** (`Qwen/Qwen3.5-397B-A17B`): 397B total parameters, 17B activated per token + - 512 experts with top-10 routing and shared experts + - Recommended: 16 nodes, 128 GPUs + +## Examples + +For checkpoint conversion, inference, finetuning recipes, and step-by-step training guides, see the [Qwen 3.5 Examples](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/examples/models/vlm/qwen35_vl/README.md). 
+ +## Hugging Face Model Cards + +- Qwen3.5 0.8B: https://huggingface.co/Qwen/Qwen3.5-0.8B +- Qwen3.5 2B: https://huggingface.co/Qwen/Qwen3.5-2B +- Qwen3.5 4B: https://huggingface.co/Qwen/Qwen3.5-4B +- Qwen3.5 9B: https://huggingface.co/Qwen/Qwen3.5-9B +- Qwen3.5 27B: https://huggingface.co/Qwen/Qwen3.5-27B +- Qwen3.5 35B-A3B (MoE): https://huggingface.co/Qwen/Qwen3.5-35B-A3B +- Qwen3.5 122B-A10B (MoE): https://huggingface.co/Qwen/Qwen3.5-122B-A10B +- Qwen3.5 397B-A17B (MoE): https://huggingface.co/Qwen/Qwen3.5-397B-A17B + +## Related Docs +- Related VLM: [Qwen3-VL](qwen3-vl.md) +- Related LLM: [Qwen](../llm/qwen.md) +- Recipe usage: [Recipe usage](../../recipe-usage.md) +- Customizing the training recipe configuration: [Configuration overview](../../training/config-container-overview.md) +- Training entry points: [Entry points](../../training/entry-points.md) diff --git a/examples/models/vlm/qwen35_vl/README.md b/examples/models/vlm/qwen35_vl/README.md new file mode 100644 index 0000000000..23f5929aab --- /dev/null +++ b/examples/models/vlm/qwen35_vl/README.md @@ -0,0 +1,119 @@ +# Qwen3.5-VL Examples + +This directory contains example scripts for Qwen3.5-VL vision-language models. + +For model introduction and architecture details, see the [Qwen3.5-VL documentation](../../../../docs/models/vlm/qwen35-vl.md). + +## Workspace Configuration + +All scripts use a `WORKSPACE` environment variable to define the base directory for checkpoints and results. By default, this is set to `/workspace`. 
You can override it: + +```bash +export WORKSPACE=/your/custom/path +``` + +Directory structure: +- `${WORKSPACE}/models/` - Converted checkpoints +- `${WORKSPACE}/results/` - Training outputs and experiment results + +## Checkpoint Conversion + +### Import HF → Megatron +To import the HF VL model to your desired Megatron path: +```bash +python examples/conversion/convert_checkpoints.py import \ + --hf-model Qwen/Qwen3.5-35B-A3B \ + --megatron-path ${WORKSPACE}/models/Qwen/Qwen3.5-35B-A3B +``` + +### Export Megatron → HF +```bash +python examples/conversion/convert_checkpoints.py export \ + --hf-model Qwen/Qwen3.5-35B-A3B \ + --megatron-path ${WORKSPACE}/models/Qwen/Qwen3.5-35B-A3B/iter_0000000 \ + --hf-path ${WORKSPACE}/models/Qwen/Qwen3.5-35B-A3B-hf-export +``` + +See the [conversion.sh](conversion.sh) script for more examples including multi-GPU round-trip validation. + +## Inference + +### Run Inference on Converted Checkpoint + +```bash +python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ + --hf_model_path Qwen/Qwen3.5-35B-A3B \ + --megatron_model_path ${WORKSPACE}/models/Qwen/Qwen3.5-35B-A3B/iter_0000000 \ + --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ + --prompt "Describe this image." \ + --max_new_tokens 100 \ + --tp 2 --pp 2 --ep 4 +``` + +Note: +- `--megatron_model_path` is optional. If not specified, the script will convert the model and then run forward. +- You can also use image URLs: `--image_path="https://example.com/image.jpg"` +- For MoE models, set `--ep` to the desired expert parallelism degree. 
+ +See the [inference.sh](inference.sh) script for commands to: +- Run inference with Hugging Face checkpoints +- Run inference with imported Megatron checkpoints +- Run inference with exported Hugging Face checkpoints + +For multi-node distributed inference—required for the largest 397B model—see the [slurm_inference.sh](slurm_inference.sh) script. + +## Finetune Recipes + +- Available recipes: + - `qwen35_vl_800m_sft_config` / `qwen35_vl_800m_peft_config`: 0.8B dense model + - `qwen35_vl_2b_sft_config` / `qwen35_vl_2b_peft_config`: 2B dense model + - `qwen35_vl_4b_sft_config` / `qwen35_vl_4b_peft_config`: 4B dense model + - `qwen35_vl_9b_sft_config` / `qwen35_vl_9b_peft_config`: 9B dense model + - `qwen35_vl_27b_sft_config` / `qwen35_vl_27b_peft_config`: 27B dense model + - `qwen35_vl_35b_a3b_sft_config` / `qwen35_vl_35b_a3b_peft_config`: 35B-A3B MoE model + - `qwen35_vl_122b_a10b_sft_config` / `qwen35_vl_122b_a10b_peft_config`: 122B-A10B MoE model + - `qwen35_vl_397b_a17b_sft_config` / `qwen35_vl_397b_a17b_peft_config`: 397B-A17B MoE model + +Before training, ensure the following environment variables are set: +1. `SAVE_DIR`: checkpoint and log saving directory +2. `HF_TOKEN`: to download models from HF Hub (if required) +3. `HF_HOME`: (optional) to avoid re-downloading models and datasets +4. `WANDB_API_KEY`: (optional) to enable WandB logging + +### Pretrain + +Pretraining is not verified for this model. + +### Supervised Fine-Tuning (SFT) + +See the [slurm_sft.sh](slurm_sft.sh) script for full parameter fine-tuning with configurable model sizes. + +### Parameter-Efficient Fine-Tuning (PEFT) with LoRA + +See the [slurm_peft.sh](slurm_peft.sh) script for LoRA fine-tuning with configurable model sizes. + +### Multi-Token Prediction (MTP) + +All Qwen3.5 models are trained with Multi-Token Prediction (`mtp_num_hidden_layers=1` in the HuggingFace config). 
MTP adds an auxiliary loss that predicts the next-next token alongside the standard next-token prediction, improving training quality. + +MTP is **enabled by default** in all recipes. The MTP layer uses standard attention (not GDN) and the same MLP architecture as the main decoder (dense MLP for dense models, MoE for MoE models). The MTP loss is scaled by `mtp_loss_scaling_factor=0.1` relative to the main LM loss. + +**Finetune with MTP** (default): +```python +cfg.model.mtp_num_layers = 1 +cfg.model.mtp_loss_scaling_factor = 0.1 +``` + +**Finetune without MTP** (discard MTP weights, standard LM loss only): +```python +cfg.model.mtp_num_layers = None +``` + +When converting checkpoints, MTP weights are included by default. Setting `mtp_num_layers = None` skips MTP weight conversion and removes the MTP auxiliary loss during training. + +### Expected Training Dynamics +We provide a [Weights & Biases report](https://api.wandb.ai/links/nvidia-nemo-fw-public/rt6uzrvf) for the expected loss curves and grad norms. + +## Evaluation + +Coming soon. diff --git a/examples/models/vlm/qwen35_vl/conversion.sh b/examples/models/vlm/qwen35_vl/conversion.sh index b7bcd54ad3..e24d6b1c3c 100755 --- a/examples/models/vlm/qwen35_vl/conversion.sh +++ b/examples/models/vlm/qwen35_vl/conversion.sh @@ -12,15 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+set -e # Workspace directory for checkpoints and results WORKSPACE=${WORKSPACE:-/workspace} -MODEL_NAME=Qwen3.5-35B-A3B # Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B, Qwen3.5-27B +# Supported model variants are: +# Qwen3.5-0.8B, Qwen3.5-2B, Qwen3.5-4B, Qwen3.5-9B, Qwen3.5-27B, Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B +MODEL_NAME=Qwen3.5-35B-A3B -if [ "${MODEL_NAME}" = "Qwen3.5-27B" ]; then +if [ "${MODEL_NAME}" = "Qwen3.5-0.8B" ] || [ "${MODEL_NAME}" = "Qwen3.5-2B" ] || [ "${MODEL_NAME}" = "Qwen3.5-4B" ] || [ "${MODEL_NAME}" = "Qwen3.5-9B" ] || [ "${MODEL_NAME}" = "Qwen3.5-27B" ]; then HF_MODEL_CLASS="Qwen3_5ForConditionalGeneration" -else + EP=1 + PP=8 + TP=1 +elif [ "${MODEL_NAME}" = "Qwen3.5-35B-A3B" ] || [ "${MODEL_NAME}" = "Qwen3.5-122B-A10B" ] || [ "${MODEL_NAME}" = "Qwen3.5-397B-A17B" ]; then HF_MODEL_CLASS="Qwen3_5MoeForConditionalGeneration" + EP=8 + PP=1 + TP=1 +else + echo "Unsupported model variant: ${MODEL_NAME}" + exit 1 fi # Make sure to upgrade to transformers >= 5.2.0 @@ -39,7 +51,7 @@ uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/co --model_class "${HF_MODEL_CLASS}" \ --image_path "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" \ --prompt "Describe this image." 
\ - --tp 1 --pp 1 --ep 8 + --tp ${TP} --pp ${PP} --ep ${EP} # Export Megatron → HF uv run python examples/conversion/convert_checkpoints.py export \ @@ -49,4 +61,4 @@ uv run python examples/conversion/convert_checkpoints.py export \ # Round-trip validation uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_megatron_roundtrip_multi_gpu.py \ - --hf-model-id Qwen/${MODEL_NAME} --tp 1 --pp 2 --ep 4 --trust-remote-code + --hf-model-id Qwen/${MODEL_NAME} --tp ${TP} --pp ${PP} --ep ${EP} diff --git a/examples/models/vlm/qwen35_vl/inference.sh b/examples/models/vlm/qwen35_vl/inference.sh index 17cfbec635..41e80720d7 100755 --- a/examples/models/vlm/qwen35_vl/inference.sh +++ b/examples/models/vlm/qwen35_vl/inference.sh @@ -13,9 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +set -e + # Workspace directory for checkpoints and results WORKSPACE=${WORKSPACE:-/workspace} -MODEL_NAME=Qwen3.5-35B-A3B # Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-27B +# Set the model name to any of the supported dense or MoE Qwen3.5-VL models: +# Dense: Qwen3.5-0.8B, Qwen3.5-2B, Qwen3.5-4B, Qwen3.5-9B, Qwen3.5-27B +# MoE: Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B +# For Qwen3.5-397B-A17B, please use the slurm_inference.sh script for multinode inference. 
+MODEL_NAME=Qwen3.5-35B-A3B + +# Set EP (Expert Parallelism) to 1 for dense models, 4 for MoE models +case "$MODEL_NAME" in + Qwen3.5-0.8B|Qwen3.5-2B|Qwen3.5-4B|Qwen3.5-9B|Qwen3.5-27B) + EP=1 + ;; + Qwen3.5-35B-A3B|Qwen3.5-122B-A10B|Qwen3.5-397B-A17B) + EP=4 + ;; + *) + echo "ERROR: Unknown model type for \$MODEL_NAME: $MODEL_NAME" + exit 1 + ;; +esac # Inference with Hugging Face checkpoints uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ @@ -23,7 +43,7 @@ uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ --prompt "Describe this image." \ --max_new_tokens 50 \ - --tp 2 --pp 2 --ep 4 + --tp 2 --pp 2 --ep ${EP} # Inference with imported Megatron checkpoints uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ @@ -32,7 +52,7 @@ uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ --prompt "Describe this image." \ --max_new_tokens 50 \ - --tp 2 --pp 2 --ep 4 + --tp 2 --pp 2 --ep ${EP} # Inference with exported HF checkpoints uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ @@ -40,4 +60,4 @@ uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ --prompt "Describe this image." 
\ --max_new_tokens 50 \ - --tp 2 --pp 2 --ep 4 + --tp 2 --pp 2 --ep ${EP} diff --git a/examples/models/vlm/qwen35_vl/slurm_peft.sh b/examples/models/vlm/qwen35_vl/slurm_peft.sh new file mode 100755 index 0000000000..9de0b3641b --- /dev/null +++ b/examples/models/vlm/qwen35_vl/slurm_peft.sh @@ -0,0 +1,195 @@ +#!/bin/bash +set -euo pipefail +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================== +# Qwen3.5 VL Parameter-Efficient Fine-Tuning (PEFT) with LoRA +# +# Supports all Qwen3.5 VL models (dense and MoE). +# LoRA/DoRA significantly reduces memory requirements. 
+# +# Usage: +# sbatch slurm_peft.sh +# +# model: 0.8B | 2B | 4B | 9B | 27B | 35B-A3B | 122B-A10B | 397B-A17B +# +# Recommended parallelism (recipe defaults for LoRA): +# 0.8B (dense): TP=1, PP=1 (1 node) +# 2B (dense): TP=1, PP=1 (1 node) +# 4B (dense): TP=1, PP=1 (1 node) +# 9B (dense): TP=2, PP=1 (1 node) +# 27B (dense): TP=2, PP=1 (1 node) +# 35B-A3B (MoE): TP=2, PP=1, EP=4 (1 node) +# 122B-A10B (MoE): TP=2, PP=1, EP=8 (1 node) +# 397B-A17B (MoE): TP=2, PP=1, EP=32 (4 nodes) +# +# Examples: +# sbatch slurm_peft.sh 4B +# sbatch --nodes=4 slurm_peft.sh 397B-A17B +# ============================================================================== + +#SBATCH --job-name=qwen35vl-lora +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --gpus-per-node=8 +#SBATCH --time=08:00:00 +#SBATCH --partition=gpu +#SBATCH --account=my_account +#SBATCH --output=qwen35vl_lora_%j.out +#SBATCH --error=qwen35vl_lora_%j.err +#SBATCH --exclusive + +# ============================================================================== +# Parse arguments +# ============================================================================== + +MODEL_SIZE="${1:?Usage: sbatch $0 (model: 0.8B|2B|4B|9B|27B|35B-A3B|122B-A10B|397B-A17B)}" + +# Map model size to HF name and recipe +case "$MODEL_SIZE" in + 0.8B) + HF_MODEL_NAME="Qwen3.5-0.8B" + RECIPE="qwen35_vl_800m_peft_config" + ;; + 2B) + HF_MODEL_NAME="Qwen3.5-2B" + RECIPE="qwen35_vl_2b_peft_config" + ;; + 4B) + HF_MODEL_NAME="Qwen3.5-4B" + RECIPE="qwen35_vl_4b_peft_config" + ;; + 9B) + HF_MODEL_NAME="Qwen3.5-9B" + RECIPE="qwen35_vl_9b_peft_config" + ;; + 27B) + HF_MODEL_NAME="Qwen3.5-27B" + RECIPE="qwen35_vl_27b_peft_config" + ;; + 35B-A3B) + HF_MODEL_NAME="Qwen3.5-35B-A3B" + RECIPE="qwen35_vl_35b_a3b_peft_config" + ;; + 122B-A10B) + HF_MODEL_NAME="Qwen3.5-122B-A10B" + RECIPE="qwen35_vl_122b_a10b_peft_config" + ;; + 397B-A17B) + HF_MODEL_NAME="Qwen3.5-397B-A17B" + RECIPE="qwen35_vl_397b_a17b_peft_config" + ;; + *) + echo "ERROR: Unknown model 
'$MODEL_SIZE'. Must be one of: 0.8B, 2B, 4B, 9B, 27B, 35B-A3B, 122B-A10B, 397B-A17B" + exit 1 + ;; +esac + +# ============================================================================== +# CONFIGURATION +# ============================================================================== + +WORKSPACE=${WORKSPACE:-/workspace} + +PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen/${HF_MODEL_NAME} +DATASET_NAME=cord_v2 +SEQ_LENGTH=4096 +TRAIN_ITERS=500 +GLOBAL_BATCH_SIZE=32 +MICRO_BATCH_SIZE=1 +EVAL_ITERS=10 +LOG_INTERVAL=1 +WANDB_PROJECT=megatron-bridge-${DATASET_NAME} + +# Container image (required) +CONTAINER_IMAGE="" +# CONTAINER_IMAGE="/path/to/container.sqsh" + +# Container mounts (optional, space-separated) +CONTAINER_MOUNTS="" +# CONTAINER_MOUNTS="/data:/data /workspace:/workspace" + +# ============================================================================== +# Environment Setup +# ============================================================================== + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_NVLS_ENABLE=0 + +# export UV_CACHE_DIR="/path/to/shared/uv_cache" +# export HF_HOME="/path/to/shared/HF_HOME" +# export HF_TOKEN="hf_your_token_here" +# export WANDB_API_KEY="your_wandb_key_here" +# export WANDB_MODE=disabled + +# ============================================================================== +# Job Execution +# ============================================================================== + +echo "======================================" +echo "Qwen3.5-VL LoRA Fine-Tuning Job" +echo "======================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Nodes: $SLURM_JOB_NUM_NODES" +echo "GPUs per node: $SLURM_GPUS_PER_NODE" +echo "Model: $HF_MODEL_NAME" +echo "Recipe: $RECIPE" +echo "PEFT: LoRA" +echo "Checkpoint: $PRETRAINED_CHECKPOINT" +echo "======================================" + +CLI_OVERRIDES="\ + checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \ + model.seq_length=$SEQ_LENGTH \ + 
train.train_iters=$TRAIN_ITERS \ + train.global_batch_size=$GLOBAL_BATCH_SIZE \ + train.micro_batch_size=$MICRO_BATCH_SIZE \ + checkpoint.save=${WORKSPACE}/results/${RECIPE}_lora \ + logger.log_interval=$LOG_INTERVAL \ + logger.wandb_project=$WANDB_PROJECT \ + logger.wandb_exp_name=${RECIPE}_${DATASET_NAME}_lora \ + dataset.maker_name=make_${DATASET_NAME}_dataset \ + dataset.seq_length=$SEQ_LENGTH" + +# For multinode runs, the recipe's online HF path can be unstable. Pass --hf_path +# with a local model directory for more reliable config loading, e.g.: +# --hf_path ${WORKSPACE}/models/Qwen/${HF_MODEL_NAME} +CMD="uv run --no-sync python scripts/training/run_recipe.py \ + --recipe $RECIPE \ + --step_func vlm_step \ + --peft_scheme lora \ + $CLI_OVERRIDES" + +echo "Executing command..." +echo "======================================" + +if [ -z "$CONTAINER_IMAGE" ]; then + echo "ERROR: CONTAINER_IMAGE must be set. Please specify a valid container image." + exit 1 +fi + +SRUN_CMD="srun --mpi=pmix --container-image=$CONTAINER_IMAGE" + +if [ -n "$CONTAINER_MOUNTS" ]; then + for mount in $CONTAINER_MOUNTS; do + SRUN_CMD="$SRUN_CMD --container-mounts=$mount" + done +fi + +$SRUN_CMD bash -c "$CMD" + +echo "======================================" +echo "Job completed" +echo "======================================" diff --git a/examples/models/vlm/qwen35_vl/slurm_sft.sh b/examples/models/vlm/qwen35_vl/slurm_sft.sh new file mode 100644 index 0000000000..0ea7a67609 --- /dev/null +++ b/examples/models/vlm/qwen35_vl/slurm_sft.sh @@ -0,0 +1,195 @@ +#!/bin/bash +set -euo pipefail +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================== +# Qwen3.5 VL Full Supervised Fine-Tuning (SFT) +# +# Supports all Qwen3.5 VL models (dense and MoE). +# For smaller setups, use LoRA/DoRA instead (see slurm_peft.sh). +# +# Usage: +# sbatch slurm_sft.sh +# +# model: 0.8B | 2B | 4B | 9B | 27B | 35B-A3B | 122B-A10B | 397B-A17B +# +# Recommended parallelism (recipe defaults for full SFT): +# 0.8B (dense): TP=1, PP=1 (1 node) +# 2B (dense): TP=1, PP=1 (1 node) +# 4B (dense): TP=2, PP=1 (1 node) +# 9B (dense): TP=4, PP=1 (1 node) +# 27B (dense): TP=4, PP=4 (2 nodes) +# 35B-A3B (MoE): TP=2, PP=1, EP=16 (2 nodes) +# 122B-A10B (MoE): TP=2, PP=6, EP=8 (4 nodes) +# 397B-A17B (MoE): TP=2, PP=4, EP=32 (16 nodes) +# +# Examples: +# sbatch slurm_sft.sh 4B +# sbatch --nodes=2 slurm_sft.sh 27B +# sbatch --nodes=16 slurm_sft.sh 397B-A17B +# ============================================================================== + +#SBATCH --job-name=qwen35vl-sft +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --gpus-per-node=8 +#SBATCH --time=24:00:00 +#SBATCH --partition=gpu +#SBATCH --account=my_account +#SBATCH --output=qwen35vl_sft_%j.out +#SBATCH --error=qwen35vl_sft_%j.err +#SBATCH --exclusive + +# ============================================================================== +# Parse arguments +# ============================================================================== + +MODEL_SIZE="${1:?Usage: sbatch $0 (model: 0.8B|2B|4B|9B|27B|35B-A3B|122B-A10B|397B-A17B)}" + +# Map model size to HF name and recipe +case "$MODEL_SIZE" in + 
0.8B) + HF_MODEL_NAME="Qwen3.5-0.8B" + RECIPE="qwen35_vl_800m_sft_config" + ;; + 2B) + HF_MODEL_NAME="Qwen3.5-2B" + RECIPE="qwen35_vl_2b_sft_config" + ;; + 4B) + HF_MODEL_NAME="Qwen3.5-4B" + RECIPE="qwen35_vl_4b_sft_config" + ;; + 9B) + HF_MODEL_NAME="Qwen3.5-9B" + RECIPE="qwen35_vl_9b_sft_config" + ;; + 27B) + HF_MODEL_NAME="Qwen3.5-27B" + RECIPE="qwen35_vl_27b_sft_config" + ;; + 35B-A3B) + HF_MODEL_NAME="Qwen3.5-35B-A3B" + RECIPE="qwen35_vl_35b_a3b_sft_config" + ;; + 122B-A10B) + HF_MODEL_NAME="Qwen3.5-122B-A10B" + RECIPE="qwen35_vl_122b_a10b_sft_config" + ;; + 397B-A17B) + HF_MODEL_NAME="Qwen3.5-397B-A17B" + RECIPE="qwen35_vl_397b_a17b_sft_config" + ;; + *) + echo "ERROR: Unknown model '$MODEL_SIZE'. Must be one of: 0.8B, 2B, 4B, 9B, 27B, 35B-A3B, 122B-A10B, 397B-A17B" + exit 1 + ;; +esac + +# ============================================================================== +# CONFIGURATION +# ============================================================================== + +WORKSPACE=${WORKSPACE:-/workspace} + +PRETRAINED_CHECKPOINT=${WORKSPACE}/models/Qwen/${HF_MODEL_NAME} +DATASET_NAME=cord_v2 +SEQ_LENGTH=4096 +TRAIN_ITERS=500 +GLOBAL_BATCH_SIZE=32 +MICRO_BATCH_SIZE=1 +EVAL_ITERS=10 +LOG_INTERVAL=1 +WANDB_PROJECT=megatron-bridge-${DATASET_NAME} + +# Container image (required) +CONTAINER_IMAGE="" +# CONTAINER_IMAGE="/path/to/container.sqsh" + +# Container mounts (optional, space-separated) +CONTAINER_MOUNTS="" +# CONTAINER_MOUNTS="/data:/data /workspace:/workspace" + +# ============================================================================== +# Environment Setup +# ============================================================================== + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_NVLS_ENABLE=0 + +# export UV_CACHE_DIR="/path/to/shared/uv_cache" +# export HF_HOME="/path/to/shared/HF_HOME" +# export HF_TOKEN="hf_your_token_here" +# export WANDB_API_KEY="your_wandb_key_here" +# export WANDB_MODE=disabled + +# 
============================================================================== +# Job Execution +# ============================================================================== + +echo "======================================" +echo "Qwen3.5-VL Full SFT Training Job" +echo "======================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Nodes: $SLURM_JOB_NUM_NODES" +echo "GPUs per node: $SLURM_GPUS_PER_NODE" +echo "Total GPUs: $((SLURM_JOB_NUM_NODES * SLURM_GPUS_PER_NODE))" +echo "Model: $HF_MODEL_NAME" +echo "Recipe: $RECIPE" +echo "Checkpoint: $PRETRAINED_CHECKPOINT" +echo "======================================" + +CLI_OVERRIDES="\ + checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \ + model.seq_length=$SEQ_LENGTH \ + train.train_iters=$TRAIN_ITERS \ + train.global_batch_size=$GLOBAL_BATCH_SIZE \ + train.micro_batch_size=$MICRO_BATCH_SIZE \ + checkpoint.save=${WORKSPACE}/results/${RECIPE}_sft \ + logger.log_interval=$LOG_INTERVAL \ + logger.wandb_project=$WANDB_PROJECT \ + logger.wandb_exp_name=${RECIPE}_${DATASET_NAME}_sft \ + dataset.maker_name=make_${DATASET_NAME}_dataset \ + dataset.seq_length=$SEQ_LENGTH" + +# For multinode runs, the recipe's online HF path can be unstable. Pass --hf_path +# with a local model directory for more reliable config loading, e.g.: +# --hf_path ${WORKSPACE}/models/Qwen/${HF_MODEL_NAME} +CMD="uv run --no-sync python scripts/training/run_recipe.py \ + --recipe $RECIPE \ + --step_func vlm_step \ + $CLI_OVERRIDES" + +echo "Executing command..." +echo "======================================" + +if [ -z "$CONTAINER_IMAGE" ]; then + echo "ERROR: CONTAINER_IMAGE must be set. Please specify a valid container image." 
+ exit 1 +fi + +SRUN_CMD="srun --mpi=pmix --container-image=$CONTAINER_IMAGE" + +if [ -n "$CONTAINER_MOUNTS" ]; then + for mount in $CONTAINER_MOUNTS; do + SRUN_CMD="$SRUN_CMD --container-mounts=$mount" + done +fi + +$SRUN_CMD bash -c "$CMD" + +echo "======================================" +echo "Job completed" +echo "======================================" diff --git a/scripts/training/run_recipe.py b/scripts/training/run_recipe.py index 4b923f5907..e927143f35 100755 --- a/scripts/training/run_recipe.py +++ b/scripts/training/run_recipe.py @@ -139,6 +139,13 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: default=None, help="Dataset type for VLM recipes (e.g., 'energon', 'mock', 'hf', 'preloaded').", ) + parser.add_argument( + "--hf_path", + type=str, + default=None, + help="HuggingFace model ID or local path to model directory. " + "Use a local path for more stable multinode training.", + ) args, cli_overrides = parser.parse_known_args() return args, cli_overrides @@ -149,6 +156,7 @@ def load_recipe( packed_sequence: bool = False, seq_length: int | None = None, dataset_type: str | None = None, + hf_path: str | None = None, ) -> ConfigContainer: """ Load recipe by name from megatron.bridge.recipes. 
@@ -159,6 +167,7 @@ def load_recipe( packed_sequence: Enable packed sequence training (default: False) seq_length: Sequence length for training (optional) dataset_type: Dataset type for VLM recipes (e.g., 'energon', 'mock', 'hf', 'preloaded') + hf_path: HuggingFace model ID or local path to model directory (optional) Returns: ConfigContainer from calling the recipe @@ -185,12 +194,14 @@ def load_recipe( accepts_packed_sequence = "packed_sequence" in params or has_var_keyword accepts_seq_length = "seq_length" in params or has_var_keyword accepts_dataset_type = "dataset_type" in params or has_var_keyword + accepts_hf_path = "hf_path" in params or has_var_keyword except (ValueError, TypeError): # If signature inspection fails, fallback conservatively accepts_peft = True # peft is widely supported, try passing it accepts_packed_sequence = False # new parameter, don't pass if unsure accepts_seq_length = False # new parameter, don't pass if unsure accepts_dataset_type = False # VLM-specific, don't pass if unsure + accepts_hf_path = False # model-specific, don't pass if unsure # Build kwargs dynamically based on what the recipe accepts kwargs = {} @@ -202,6 +213,8 @@ def load_recipe( kwargs["seq_length"] = seq_length if accepts_dataset_type and dataset_type is not None: kwargs["dataset_type"] = dataset_type + if accepts_hf_path and hf_path is not None: + kwargs["hf_path"] = hf_path try: return config_builder(**kwargs) @@ -238,6 +251,7 @@ def main() -> None: args.packed_sequence, args.seq_length, args.dataset_type, + args.hf_path, ) config = process_config_with_overrides( diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py index fd26808bc5..a2c9feb136 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py @@ -118,6 +118,9 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen35VLMoEModelProvi provider = 
Qwen35VLMoEModelProvider(**provider_kwargs) + # For VLMs, tie_word_embeddings lives on the top-level config, not text_config. + provider.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + # --- Common Qwen3 LLM settings --- provider.normalization = "RMSNorm" provider.gated_linear_unit = True @@ -438,6 +441,10 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen35VLModelProvider provider = Qwen35VLModelProvider(**provider_kwargs) + # For VLMs, tie_word_embeddings lives on the top-level config, not text_config. + # text_config inherits PretrainedConfig's default of True which is wrong for 9B/27B. + provider.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + # --- Common Qwen3 LLM settings --- provider.normalization = "RMSNorm" provider.gated_linear_unit = True diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py index cac4042a24..458a4e81cb 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py @@ -179,11 +179,20 @@ def __post_init__(self): _check_qwen3_5_available() if self.vision_config is None: self.vision_config = Qwen3_5VisionConfig() + self.validate_parallelism() + super().__post_init__() + + def validate_parallelism(self): + """Validate that parallelism settings are compatible with this model's architecture. + + Call this after mutating parallelism attributes (e.g. tensor_model_parallel_size) + on an already-constructed provider, since __post_init__ only runs at construction time. + """ if self.num_query_groups < self.tensor_model_parallel_size: raise ValueError( - f"TP size {self.tensor_model_parallel_size} should be less than or equal to num_query_groups {self.num_query_groups}. Please use a smaller TP size." 
+ f"TP size {self.tensor_model_parallel_size} should be less than or equal to " + f"num_query_groups {self.num_query_groups}. Please use a smaller TP size." ) - super().__post_init__() def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: """Provide a Qwen3.5 VL dense model instance with vision and language components.""" @@ -346,11 +355,20 @@ def __post_init__(self): _check_qwen3_5_moe_available() if self.vision_config is None: self.vision_config = Qwen3_5MoeVisionConfig() + self.validate_parallelism() + super().__post_init__() + + def validate_parallelism(self): + """Validate that parallelism settings are compatible with this model's architecture. + + Call this after mutating parallelism attributes (e.g. tensor_model_parallel_size) + on an already-constructed provider, since __post_init__ only runs at construction time. + """ if self.num_query_groups < self.tensor_model_parallel_size: raise ValueError( - f"TP size {self.tensor_model_parallel_size} should be less than or equal to num_query_groups {self.num_query_groups}. Please use a smaller TP size." + f"TP size {self.tensor_model_parallel_size} should be less than or equal to " + f"num_query_groups {self.num_query_groups}. Please use a smaller TP size." ) - super().__post_init__() def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: """Provide a Qwen3.5 VL model instance with vision and language components. 
diff --git a/src/megatron/bridge/recipes/qwen_vl/__init__.py b/src/megatron/bridge/recipes/qwen_vl/__init__.py index 4d89c691c3..35ef162de1 100644 --- a/src/megatron/bridge/recipes/qwen_vl/__init__.py +++ b/src/megatron/bridge/recipes/qwen_vl/__init__.py @@ -33,8 +33,48 @@ qwen25_vl_72b_sft_config, ) +# Qwen3.5 models +from .qwen35_vl import ( + qwen35_vl_2b_peft_config, + qwen35_vl_2b_sft_config, + qwen35_vl_4b_peft_config, + qwen35_vl_4b_sft_config, + qwen35_vl_9b_peft_config, + qwen35_vl_9b_sft_config, + qwen35_vl_27b_peft_config, + qwen35_vl_27b_sft_config, + qwen35_vl_35b_a3b_peft_config, + qwen35_vl_35b_a3b_sft_config, + qwen35_vl_122b_a10b_peft_config, + qwen35_vl_122b_a10b_sft_config, + qwen35_vl_397b_a17b_peft_config, + qwen35_vl_397b_a17b_sft_config, + qwen35_vl_800m_peft_config, + qwen35_vl_800m_sft_config, +) + __all__ = [ + # Qwen3.5-VL SFT configs — dense + "qwen35_vl_800m_sft_config", + "qwen35_vl_2b_sft_config", + "qwen35_vl_4b_sft_config", + "qwen35_vl_9b_sft_config", + "qwen35_vl_27b_sft_config", + # Qwen3.5-VL SFT configs — MoE + "qwen35_vl_35b_a3b_sft_config", + "qwen35_vl_122b_a10b_sft_config", + "qwen35_vl_397b_a17b_sft_config", + # Qwen3.5-VL PEFT configs — dense + "qwen35_vl_800m_peft_config", + "qwen35_vl_2b_peft_config", + "qwen35_vl_4b_peft_config", + "qwen35_vl_9b_peft_config", + "qwen35_vl_27b_peft_config", + # Qwen3.5-VL PEFT configs — MoE + "qwen35_vl_35b_a3b_peft_config", + "qwen35_vl_122b_a10b_peft_config", + "qwen35_vl_397b_a17b_peft_config", # Qwen2.5-VL SFT configs "qwen25_vl_3b_sft_config", "qwen25_vl_7b_sft_config", diff --git a/src/megatron/bridge/recipes/qwen_vl/qwen35_vl.py b/src/megatron/bridge/recipes/qwen_vl/qwen35_vl.py new file mode 100644 index 0000000000..c0e9b2b89e --- /dev/null +++ b/src/megatron/bridge/recipes/qwen_vl/qwen35_vl.py @@ -0,0 +1,1771 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3.5-VL finetuning recipes. + +This module provides SFT and PEFT configurations for Qwen3.5-VL models: + +- **Dense**: 800M, 2B, 4B, 9B, 27B +- **MoE**: 35B-A3B, 122B-A10B, 397B-A17B +""" + +import torch + +from megatron.bridge import AutoBridge +from megatron.bridge.peft.base import PEFT +from megatron.bridge.recipes.common import _peft_common_vlm, _sft_common_vlm +from megatron.bridge.recipes.utils.finetune_utils import default_peft_config +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.training.config import ConfigContainer + + +# ============================================================================= +# Qwen3.5-VL 800M SFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_800m_sft_config(hf_path: str = "Qwen/Qwen3.5-0.8B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 800M (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=5e-6 (full SFT) + - Sequence length: 4096 + + Note: num_kv_heads=2, so max TP=2. + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=5e-6, + min_lr=5e-7, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = 
torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 2B SFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_2b_sft_config(hf_path: str = "Qwen/Qwen3.5-2B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 2B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=5e-6 (full SFT) + - Sequence length: 4096 + + Note: num_kv_heads=2, so max TP=2. + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=5e-6, + min_lr=5e-7, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = 
torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 4B SFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_4b_sft_config(hf_path: str = "Qwen/Qwen3.5-4B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 4B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=2, PP=1 + - LR=5e-6 (full SFT) + - Sequence length: 4096 + + Note: num_kv_heads=4, so max TP=4. + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=5e-6, + min_lr=5e-7, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = 
torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 9B SFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_9b_sft_config(hf_path: str = "Qwen/Qwen3.5-9B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 9B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=4, PP=1 + - LR=5e-6 (full SFT) + - Sequence length: 4096 + + Note: num_kv_heads=4, so max TP=4. + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 4 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=5e-6, + min_lr=5e-7, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = 
torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 27B SFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_27b_sft_config(hf_path: str = "Qwen/Qwen3.5-27B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 27B (dense). + + Default configuration: 2 nodes, 16 GPUs total + - TP=4, PP=4 + - LR=5e-6 (full SFT) + - Sequence length: 4096 + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 4 + cfg.model.pipeline_model_parallel_size = 4 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=5e-6, + min_lr=5e-7, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + 
cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 35B-A3B SFT Configuration (MoE) +# ============================================================================= +def qwen35_vl_35b_a3b_sft_config(hf_path: str = "Qwen/Qwen3.5-35B-A3B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 35B-A3B (MoE). + + Default configuration: 2 nodes, 16 GPUs + - TP=2, PP=1, EP=16 + - LR=2e-5 (full SFT) + - Sequence length: 4096 + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 16 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + 
lr_warmup_iters=200,
+        lr_decay_iters=300000,
+        max_lr=2e-5,
+        min_lr=2e-6,
+    )
+    cfg.optimizer = opt_cfg
+    cfg.scheduler = scheduler_cfg
+
+    # Optimizer precision settings (disabled by default for full precision)
+    cfg.optimizer.use_precision_aware_optimizer = False
+    cfg.optimizer.main_grads_dtype = torch.float32
+    cfg.optimizer.main_params_dtype = torch.float32
+    cfg.optimizer.exp_avg_dtype = torch.float32
+    cfg.optimizer.exp_avg_sq_dtype = torch.float32
+
+    # Dataset configuration
+    cfg.dataset.seq_length = 4096
+    cfg.dataset.hf_processor_path = hf_path
+    cfg.dataset.pack_sequences_in_batch = False
+
+    # DDP settings
+    cfg.ddp.overlap_grad_reduce = False
+    cfg.ddp.overlap_param_gather = False
+    cfg.ddp.check_for_nan_in_grad = True
+    cfg.ddp.use_distributed_optimizer = True
+    cfg.ddp.grad_reduce_in_fp32 = True
+    cfg.ddp.average_in_collective = True
+    cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
+
+    # Comm overlap settings (MoE)
+    cfg.comm_overlap = None
+
+    # FP8 and MXFP8 settings (disabled by default)
+    cfg.mixed_precision = "bf16_mixed"
+
+    return cfg
+
+
+# =============================================================================
+# Qwen3.5-VL 122B-A10B SFT Configuration (MoE)
+# =============================================================================
+def qwen35_vl_122b_a10b_sft_config(hf_path: str = "Qwen/Qwen3.5-122B-A10B") -> ConfigContainer:
+    """Return a full SFT config for Qwen3.5-VL 122B-A10B (MoE).
+
+    Default configuration: 4 nodes, 32 GPUs
+    - TP=2, PP=6, EP=8 (NOTE(review): TP*PP=12 does not evenly divide 32 GPUs; PP=4 would fit 32 GPUs with EP=8 dividing TP*DP=8 -- confirm intended node count or PP)
+    - LR=2e-5 (full SFT)
+    - Sequence length: 4096
+
+    Args:
+        hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 6 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 8 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving — activation recomputation enabled for this large model + cfg.model.recompute_granularity = "full" + cfg.model.recompute_method = "uniform" + cfg.model.recompute_num_layers = 1 + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 36 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # 
Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=2e-5, + min_lr=2e-6, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # Comm overlap settings (MoE) + cfg.comm_overlap = None + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 397B-A17B SFT Configuration (MoE) +# ============================================================================= +def qwen35_vl_397b_a17b_sft_config(hf_path: str = "Qwen/Qwen3.5-397B-A17B") -> ConfigContainer: + """Return a full SFT config for Qwen3.5-VL 397B-A17B (MoE). + + Default configuration: 16 nodes, 128 GPUs + - TP=2, PP=4, EP=32 + - LR=2e-5 (full SFT) + - Sequence length: 4096 + + Args: + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _sft_common_vlm() + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 4 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 32 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving — activation recomputation enabled for this large model + cfg.model.recompute_granularity = "full" + cfg.model.recompute_method = "uniform" + cfg.model.recompute_num_layers = 1 + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + 
# Optimizer - lower LR for full SFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=2e-5, + min_lr=2e-6, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # Comm overlap settings (MoE) + cfg.comm_overlap = None + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 800M PEFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_800m_peft_config( + peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-0.8B" +) -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 800M (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=1e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=1e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + 
cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 2B PEFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_2b_peft_config(peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-2B") -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 2B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=1e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=1e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + 
cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 4B PEFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_4b_peft_config(peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-4B") -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 4B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=1e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower TP for PEFT + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=1e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + 
cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 9B PEFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_9b_peft_config(peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-9B") -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 9B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=1, PP=1 + - LR=1e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower TP for PEFT + cfg.model.tensor_model_parallel_size = 1 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=1e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + 
cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 27B PEFT Configuration (Dense) +# ============================================================================= +def qwen35_vl_27b_peft_config(peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-27B") -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 27B (dense). + + Default configuration: 1 node, 8 GPUs + - TP=2, PP=1 + - LR=1e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower TP/PP for PEFT + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = None + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = False + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=1e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + 
cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 35B-A3B PEFT Configuration (MoE) +# ============================================================================= +def qwen35_vl_35b_a3b_peft_config( + peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-35B-A3B" +) -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 35B-A3B (MoE). + + Default configuration: 1 node, 8 GPUs + - TP=2, PP=1, EP=4 + - LR=2e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower EP for PEFT + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 4 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc 
= True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=2e-4, + min_lr=1e-4, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # Comm overlap settings (MoE) + cfg.comm_overlap = None + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 122B-A10B PEFT Configuration (MoE) +# ============================================================================= +def qwen35_vl_122b_a10b_peft_config( + peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-122B-A10B" +) -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 122B-A10B (MoE). + + Default configuration: 2 nodes, 16 GPUs + - TP=2, PP=1, EP=8 + - LR=2e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower PP for PEFT + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 8 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 36 + cfg.train.micro_batch_size = 1 + cfg.train.manual_gc 
= True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=2e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # Comm overlap settings (MoE) + cfg.comm_overlap = None + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg + + +# ============================================================================= +# Qwen3.5-VL 397B-A17B PEFT Configuration (MoE) +# ============================================================================= +def qwen35_vl_397b_a17b_peft_config( + peft_scheme: str | PEFT = "lora", hf_path: str = "Qwen/Qwen3.5-397B-A17B" +) -> ConfigContainer: + """Return a PEFT config for Qwen3.5-VL 397B-A17B (MoE). + + Default configuration: 4 nodes, 32 GPUs + - TP=2, PP=1, EP=32 + - LR=2e-4 (PEFT) + - Sequence length: 4096 + + Args: + peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance. + hf_path: HuggingFace model ID or local path to model directory. 
+ """ + cfg = _peft_common_vlm() + + # PEFT scheme + if isinstance(peft_scheme, str) and peft_scheme.lower() in ["lora", "dora"]: + cfg.peft = default_peft_config(peft_scheme) + else: + cfg.peft = peft_scheme + + # Model configuration + cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False) + cfg.model.seq_length = 4096 + + # Parallel settings - lower PP for PEFT + cfg.model.tensor_model_parallel_size = 2 + cfg.model.pipeline_model_parallel_size = 1 + cfg.model.pipeline_dtype = torch.bfloat16 + cfg.model.virtual_pipeline_model_parallel_size = None + cfg.model.expert_model_parallel_size = 32 + cfg.model.expert_tensor_parallel_size = 1 + cfg.model.context_parallel_size = 1 + cfg.model.sequence_parallel = True + + # VLM-specific settings + cfg.model.freeze_language_model = False + cfg.model.freeze_vision_model = False + cfg.model.freeze_vision_projection = False + + # TE / Transformer implementation + cfg.model.transformer_impl = "transformer_engine" + + # CUDA Graph settings + cfg.model.cuda_graph_impl = "none" + cfg.model.cuda_graph_scope = "full" + cfg.model.cuda_graph_warmup_steps = 3 + + # Kernel selections + cfg.model.attention_backend = "auto" + cfg.model.cross_entropy_loss_fusion = True + cfg.model.cross_entropy_fusion_impl = "native" + + # MoE kernel selections + cfg.model.moe_router_fusion = False + cfg.model.moe_permute_fusion = True + cfg.model.moe_grouped_gemm = True + + # Memory saving (disabled by default) + cfg.model.recompute_granularity = None + cfg.model.recompute_modules = None + cfg.model.fine_grained_activation_offloading = False + cfg.model.offload_modules = None + + # MoE overlap + cfg.model.moe_shared_expert_overlap = False + + # MoE force balance + cfg.model.moe_router_force_load_balancing = False + + # MoE FP8 padding + cfg.model.moe_router_padding_for_fp8 = False + + # Training config + cfg.train.train_iters = 300000 + cfg.train.global_batch_size = 32 + cfg.train.micro_batch_size = 1 + 
cfg.train.manual_gc = True + cfg.train.manual_gc_interval = 100 + cfg.train.manual_gc_eval = 100 + + # Optimizer - higher LR for PEFT + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=200, + lr_decay_iters=300000, + max_lr=2e-4, + min_lr=3e-5, + ) + cfg.optimizer = opt_cfg + cfg.scheduler = scheduler_cfg + + # Optimizer precision settings (disabled by default for full precision) + cfg.optimizer.use_precision_aware_optimizer = False + cfg.optimizer.main_grads_dtype = torch.float32 + cfg.optimizer.main_params_dtype = torch.float32 + cfg.optimizer.exp_avg_dtype = torch.float32 + cfg.optimizer.exp_avg_sq_dtype = torch.float32 + + # Dataset configuration + cfg.dataset.seq_length = 4096 + cfg.dataset.hf_processor_path = hf_path + cfg.dataset.pack_sequences_in_batch = False + + # DDP settings + cfg.ddp.overlap_grad_reduce = False + cfg.ddp.overlap_param_gather = False + cfg.ddp.check_for_nan_in_grad = True + cfg.ddp.use_distributed_optimizer = True + cfg.ddp.grad_reduce_in_fp32 = True + cfg.ddp.average_in_collective = True + cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params" + + # Comm overlap settings (MoE) + cfg.comm_overlap = None + + # FP8 and MXFP8 settings (disabled by default) + cfg.mixed_precision = "bf16_mixed" + + return cfg diff --git a/tests/functional_tests/recipes/test_qwen35_vl_recipes_finetune.py b/tests/functional_tests/recipes/test_qwen35_vl_recipes_finetune.py new file mode 100644 index 0000000000..b320c3966d --- /dev/null +++ b/tests/functional_tests/recipes/test_qwen35_vl_recipes_finetune.py @@ -0,0 +1,206 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functional smoke tests for Qwen3.5-VL finetuning recipes.
+
+Covers four training scenarios:
+1. SFT with nothing frozen (all modules trainable)
+2. SFT with language model frozen (train vision + projection)
+3. SFT with vision + language frozen (train projection only)
+4. SFT with activation recomputation
+
+Run with:
+    torchrun --nproc_per_node=2 -m pytest tests/functional_tests/recipes/test_qwen35_vl_recipes_finetune.py -v
+"""
+
+import pytest
+
+from megatron.bridge.recipes.qwen_vl.qwen35_vl import qwen35_vl_27b_sft_config
+from tests.functional_tests.recipes.utils import run_pretrain_vl_recipe_test
+
+
+pytestmark = pytest.mark.integration
+
+
+_TP2_PP1 = {"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 1}
+_TINY_MODEL = {"num_layers": 4}
+
+
+# ---------------------------------------------------------------------------
+# Scenario 1: SFT — nothing frozen
+# ---------------------------------------------------------------------------
+
+QWEN35_VL_SFT_NONE_FROZEN = [
+    (
+        qwen35_vl_27b_sft_config,
+        "qwen35_vl_27b_sft_none_frozen",
+        _TP2_PP1,
+        {
+            **_TINY_MODEL,
+            "freeze_language_model": False,
+            "freeze_vision_model": False,
+            "freeze_vision_projection": False,
+        },
+    ),
+]
+
+# ---------------------------------------------------------------------------
+# Scenario 2: SFT — language model frozen
+# ---------------------------------------------------------------------------
+
+QWEN35_VL_SFT_LM_FROZEN = [
+    (
+        qwen35_vl_27b_sft_config,
+        "qwen35_vl_27b_sft_lm_frozen",
+        _TP2_PP1,
+        {
+            **_TINY_MODEL,
+            "freeze_language_model": True,
+            "freeze_vision_model": False,
+            "freeze_vision_projection": False,
+        },
+    ),
+]
+
+# ---------------------------------------------------------------------------
+# Scenario 3: SFT — vision + language frozen (train projection only)
+# ---------------------------------------------------------------------------
+
+QWEN35_VL_SFT_PROJ_ONLY = [
+    (
+        qwen35_vl_27b_sft_config,
+        "qwen35_vl_27b_sft_projection_only",
+        _TP2_PP1,
+        {
+            **_TINY_MODEL,
+            "freeze_language_model": True,
+            "freeze_vision_model": True,
+            "freeze_vision_projection": False,
+        },
+    ),
+]
+
+# ---------------------------------------------------------------------------
+# Scenario 4: SFT — activation recomputation
+# ---------------------------------------------------------------------------
+
+QWEN35_VL_SFT_RECOMPUTE = [
+    (
+        qwen35_vl_27b_sft_config,
+        "qwen35_vl_27b_sft_recompute",
+        _TP2_PP1,
+        {
+            **_TINY_MODEL,
+            "recompute_granularity": "full",
+            "recompute_method": "uniform",
+            "recompute_num_layers": 1,
+        },
+    ),
+]
+
+
+class TestQwen35VLFinetuneRecipes:
+    """Functional tests covering SFT freeze combos and recompute."""
+
+    @pytest.fixture(autouse=True)
+    def _reset_microbatch_calculator(self):
+        """Ensure the global microbatch calculator is cleared between tests.
+
+        If a previous test fails mid-pretrain, destroy_global_state() never
+        runs and the calculator leaks into the next test.
+ """ + from megatron.core.num_microbatches_calculator import ( + _GLOBAL_NUM_MICROBATCHES_CALCULATOR, + destroy_num_microbatches_calculator, + ) + + if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is not None: + destroy_num_microbatches_calculator() + + yield + + if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is not None: + destroy_num_microbatches_calculator() + + # ----------------------------------------------------------------------- + # SFT scenarios + # ----------------------------------------------------------------------- + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "config_func,recipe_name,parallelism_overrides,model_overrides", + QWEN35_VL_SFT_NONE_FROZEN, + ) + def test_sft_nothing_frozen(self, config_func, recipe_name, parallelism_overrides, model_overrides, tmp_path): + """Scenario 1: all modules trainable.""" + run_pretrain_vl_recipe_test( + config_func, + recipe_name, + tmp_path, + model_overrides=model_overrides, + **parallelism_overrides, + ) + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "config_func,recipe_name,parallelism_overrides,model_overrides", + QWEN35_VL_SFT_LM_FROZEN, + ) + def test_sft_language_model_frozen( + self, config_func, recipe_name, parallelism_overrides, model_overrides, tmp_path + ): + """Scenario 2: language model frozen, train vision + projection.""" + run_pretrain_vl_recipe_test( + config_func, + recipe_name, + tmp_path, + model_overrides=model_overrides, + **parallelism_overrides, + ) + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "config_func,recipe_name,parallelism_overrides,model_overrides", + QWEN35_VL_SFT_PROJ_ONLY, + ) + def test_sft_vision_and_language_frozen( + self, config_func, recipe_name, parallelism_overrides, model_overrides, tmp_path + ): + """Scenario 3: vision + language frozen, train projection only.""" + run_pretrain_vl_recipe_test( + config_func, + recipe_name, + tmp_path, + model_overrides=model_overrides, + **parallelism_overrides, + ) + + # 
----------------------------------------------------------------------- + # Recompute + # ----------------------------------------------------------------------- + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "config_func,recipe_name,parallelism_overrides,model_overrides", + QWEN35_VL_SFT_RECOMPUTE, + ) + def test_recompute(self, config_func, recipe_name, parallelism_overrides, model_overrides, tmp_path): + """SFT with activation recomputation.""" + run_pretrain_vl_recipe_test( + config_func, + recipe_name, + tmp_path, + model_overrides=model_overrides, + **parallelism_overrides, + ) diff --git a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py index 605e6627e1..c7aa09207c 100644 --- a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py +++ b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py @@ -209,7 +209,7 @@ def test_provider_bridge_dtype_handling(self, mock_dtype, bridge, mock_pretraine def test_provider_bridge_tied_embeddings(self, bridge): text_config = _make_dense_text_config() text_config.tie_word_embeddings = True - pretrained = _make_mock_pretrained(text_config, _make_vision_config()) + pretrained = _make_mock_pretrained(text_config, _make_vision_config(), tie_word_embeddings=True) provider = bridge.provider_bridge(pretrained) assert provider.share_embeddings_and_output_weights is True diff --git a/tests/unit_tests/recipes/qwen_vl/test_qwen35_vl_recipes.py b/tests/unit_tests/recipes/qwen_vl/test_qwen35_vl_recipes.py new file mode 100644 index 0000000000..1b14d24de0 --- /dev/null +++ b/tests/unit_tests/recipes/qwen_vl/test_qwen35_vl_recipes.py @@ -0,0 +1,629 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Test purpose: +# - Parametrize over all exported Qwen3.5-VL recipe functions. +# - Monkeypatch AutoBridge and the provider to avoid I/O and heavy model init. +# - Build a config and assert it forms a valid ConfigContainer. +# - Verify dataset provider selection, parallelism fields, freeze options, and PEFT defaults. +# + +import importlib +from typing import Callable + +import pytest +import torch + + +_qwen35_vl_module = importlib.import_module("megatron.bridge.recipes.qwen_vl.qwen35_vl") + +# SFT configs (parameterless) +_QWEN35_VL_SFT_FUNCS = [ + _qwen35_vl_module.qwen35_vl_800m_sft_config, + _qwen35_vl_module.qwen35_vl_2b_sft_config, + _qwen35_vl_module.qwen35_vl_4b_sft_config, + _qwen35_vl_module.qwen35_vl_9b_sft_config, + _qwen35_vl_module.qwen35_vl_27b_sft_config, + _qwen35_vl_module.qwen35_vl_35b_a3b_sft_config, + _qwen35_vl_module.qwen35_vl_122b_a10b_sft_config, + _qwen35_vl_module.qwen35_vl_397b_a17b_sft_config, +] + +# PEFT configs (take peft_scheme parameter) +_QWEN35_VL_PEFT_FUNCS = [ + _qwen35_vl_module.qwen35_vl_800m_peft_config, + _qwen35_vl_module.qwen35_vl_2b_peft_config, + _qwen35_vl_module.qwen35_vl_4b_peft_config, + _qwen35_vl_module.qwen35_vl_9b_peft_config, + _qwen35_vl_module.qwen35_vl_27b_peft_config, + _qwen35_vl_module.qwen35_vl_35b_a3b_peft_config, + _qwen35_vl_module.qwen35_vl_122b_a10b_peft_config, + _qwen35_vl_module.qwen35_vl_397b_a17b_peft_config, +] + + +class _FakeModelCfg: + """Fake model configuration for testing.""" + + def __init__(self): + self.tensor_model_parallel_size = 1 + self.pipeline_model_parallel_size 
= 1 + self.pipeline_dtype = None + self.virtual_pipeline_model_parallel_size = None + self.context_parallel_size = 1 + self.expert_model_parallel_size = 1 + self.expert_tensor_parallel_size = 1 + self.sequence_parallel = False + self.seq_length = 64 + self.freeze_language_model = False + self.freeze_vision_model = False + self.freeze_vision_projection = False + + def finalize(self): + return None + + +class _FakeAutoBridge: + """Fake AutoBridge for testing.""" + + @staticmethod + def from_hf_pretrained(hf_path: str): + return _FakeAutoBridge() + + def to_megatron_provider(self, load_weights: bool = False): + return _FakeModelCfg() + + +def _assert_basic_config(cfg): + """Assert that a config has all required components.""" + from megatron.bridge.training.config import ConfigContainer + + assert isinstance(cfg, ConfigContainer) + assert cfg.model is not None + assert cfg.train is not None + assert cfg.optimizer is not None + assert cfg.scheduler is not None + assert cfg.dataset is not None + assert cfg.logger is not None + assert cfg.tokenizer is not None + assert cfg.checkpoint is not None + assert cfg.rng is not None + + assert cfg.train.global_batch_size >= 1 + assert cfg.train.micro_batch_size >= 1 + assert cfg.dataset.seq_length >= 1 + + +# --------------------------------------------------------------------------- +# Basic SFT recipe building tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("recipe_func", _QWEN35_VL_SFT_FUNCS) +def test_each_qwen35_vl_sft_recipe_builds_config(recipe_func: Callable, monkeypatch: pytest.MonkeyPatch): + """Test that each Qwen3.5-VL SFT recipe function builds a valid configuration.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = recipe_func() + + _assert_basic_config(cfg) + + if hasattr(cfg, "tokenizer") and hasattr(cfg.tokenizer, "tokenizer_type"): + assert cfg.tokenizer.tokenizer_type == "NullTokenizer" + + assert 
getattr(cfg.model, "tensor_model_parallel_size", 1) >= 1 + assert getattr(cfg.model, "pipeline_model_parallel_size", 1) >= 1 + + assert hasattr(cfg.model, "freeze_language_model") + assert hasattr(cfg.model, "freeze_vision_model") + assert hasattr(cfg.model, "freeze_vision_projection") + + assert cfg.peft is None + + +# --------------------------------------------------------------------------- +# Basic PEFT recipe building tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("recipe_func", _QWEN35_VL_PEFT_FUNCS) +def test_each_qwen35_vl_peft_recipe_builds_config(recipe_func: Callable, monkeypatch: pytest.MonkeyPatch): + """Test that each Qwen3.5-VL PEFT recipe function builds a valid configuration.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = recipe_func() + + _assert_basic_config(cfg) + + if hasattr(cfg, "tokenizer") and hasattr(cfg.tokenizer, "tokenizer_type"): + assert cfg.tokenizer.tokenizer_type == "NullTokenizer" + + assert getattr(cfg.model, "tensor_model_parallel_size", 1) >= 1 + assert getattr(cfg.model, "pipeline_model_parallel_size", 1) >= 1 + + assert hasattr(cfg.model, "freeze_language_model") + assert hasattr(cfg.model, "freeze_vision_model") + assert hasattr(cfg.model, "freeze_vision_projection") + + assert cfg.peft is not None + assert hasattr(cfg.peft, "dim") + assert hasattr(cfg.peft, "alpha") + + +# --------------------------------------------------------------------------- +# PEFT schemes +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("recipe_func", _QWEN35_VL_PEFT_FUNCS) +@pytest.mark.parametrize("peft_scheme", ["lora", "dora"]) +def test_qwen35_vl_peft_schemes(recipe_func: Callable, peft_scheme: str, monkeypatch: pytest.MonkeyPatch): + """Test that different PEFT schemes are correctly applied for Qwen3.5-VL models.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", 
_FakeAutoBridge) + + cfg = recipe_func(peft_scheme=peft_scheme) + + _assert_basic_config(cfg) + + assert cfg.peft is not None + assert hasattr(cfg.peft, "dim") + assert hasattr(cfg.peft, "alpha") + + +# --------------------------------------------------------------------------- +# 800M dense defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_800m_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """800M SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is None + assert cfg.optimizer.lr == 5e-6 + + +def test_qwen35_vl_800m_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """800M PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is not None + assert cfg.optimizer.lr == 1e-4 + + +# --------------------------------------------------------------------------- +# 2B dense defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_2b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """2B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_2b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is None + assert cfg.optimizer.lr == 5e-6 + + +def 
test_qwen35_vl_2b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """2B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_2b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is not None + assert cfg.optimizer.lr == 1e-4 + + +# --------------------------------------------------------------------------- +# 4B dense defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_4b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """4B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_4b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is None + assert cfg.optimizer.lr == 5e-6 + + +def test_qwen35_vl_4b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """4B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_4b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is not None + assert cfg.optimizer.lr == 1e-4 + + +# --------------------------------------------------------------------------- +# 9B dense defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_9b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """9B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + 
cfg = _qwen35_vl_module.qwen35_vl_9b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 4 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is None + assert cfg.optimizer.lr == 5e-6 + + +def test_qwen35_vl_9b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """9B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_9b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 1 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is not None + assert cfg.optimizer.lr == 1e-4 + + +# --------------------------------------------------------------------------- +# 27B dense defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_27b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """27B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_27b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 4 + assert cfg.model.pipeline_model_parallel_size == 4 + assert cfg.model.pipeline_dtype == torch.bfloat16 + assert cfg.peft is None + assert cfg.optimizer.lr == 5e-6 + + +def test_qwen35_vl_27b_peft_lora_defaults(monkeypatch: pytest.MonkeyPatch): + """27B LoRA should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_27b_peft_config(peft_scheme="lora") + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.model.pipeline_dtype is None + assert cfg.peft is not None + assert cfg.peft.dim == 32 + assert cfg.peft.alpha == 32 + 
assert cfg.optimizer.lr == 1e-4 + + +def test_qwen35_vl_27b_peft_dora_defaults(monkeypatch: pytest.MonkeyPatch): + """27B DoRA should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_27b_peft_config(peft_scheme="dora") + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.peft is not None + assert cfg.peft.dim == 32 + assert cfg.peft.alpha == 64 + + +# --------------------------------------------------------------------------- +# 35B-A3B MoE defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_35b_a3b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """35B-A3B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_35b_a3b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.model.expert_model_parallel_size == 16 + assert cfg.model.pipeline_dtype == torch.bfloat16 + assert cfg.peft is None + assert cfg.optimizer.lr == 2e-5 + + +def test_qwen35_vl_35b_a3b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """35B-A3B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_35b_a3b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.model.expert_model_parallel_size == 4 + assert cfg.peft is not None + assert cfg.optimizer.lr == 2e-4 + + +# --------------------------------------------------------------------------- +# 122B-A10B MoE defaults +# 
--------------------------------------------------------------------------- + + +def test_qwen35_vl_122b_a10b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """122B-A10B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_122b_a10b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 6 + assert cfg.model.expert_model_parallel_size == 8 + assert cfg.model.pipeline_dtype == torch.bfloat16 + assert cfg.peft is None + assert cfg.optimizer.lr == 2e-5 + assert cfg.model.recompute_granularity == "full" + + +def test_qwen35_vl_122b_a10b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """122B-A10B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_122b_a10b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.model.expert_model_parallel_size == 8 + assert cfg.model.pipeline_dtype == torch.bfloat16 + assert cfg.peft is not None + assert cfg.optimizer.lr == 2e-4 + + +# --------------------------------------------------------------------------- +# 397B-A17B MoE defaults +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_397b_a17b_sft_defaults(monkeypatch: pytest.MonkeyPatch): + """397B-A17B SFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_397b_a17b_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 4 + assert cfg.model.expert_model_parallel_size == 32 + assert 
cfg.model.pipeline_dtype == torch.bfloat16 + assert cfg.peft is None + assert cfg.optimizer.lr == 2e-5 + assert cfg.model.recompute_granularity == "full" + + +def test_qwen35_vl_397b_a17b_peft_defaults(monkeypatch: pytest.MonkeyPatch): + """397B-A17B PEFT should have correct default parallelism and learning rate.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_397b_a17b_peft_config() + + _assert_basic_config(cfg) + assert cfg.model.tensor_model_parallel_size == 2 + assert cfg.model.pipeline_model_parallel_size == 1 + assert cfg.model.expert_model_parallel_size == 32 + assert cfg.peft is not None + assert cfg.optimizer.lr == 2e-4 + assert cfg.model.pipeline_dtype == torch.bfloat16 + + +# --------------------------------------------------------------------------- +# Common config properties +# --------------------------------------------------------------------------- + + +def test_qwen35_vl_sft_has_hf_dataset_provider(monkeypatch: pytest.MonkeyPatch): + """Test that SFT configs use HFDatasetConversationProvider by default.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + from megatron.bridge.data.vlm_datasets.hf_provider import HFDatasetConversationProvider + + assert isinstance(cfg.dataset, HFDatasetConversationProvider) + + +def test_qwen35_vl_peft_has_hf_dataset_provider(monkeypatch: pytest.MonkeyPatch): + """Test that PEFT configs use HFDatasetConversationProvider by default.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_peft_config() + + from megatron.bridge.data.vlm_datasets.hf_provider import HFDatasetConversationProvider + + assert isinstance(cfg.dataset, HFDatasetConversationProvider) + + +def test_qwen35_vl_sft_freeze_defaults(monkeypatch: pytest.MonkeyPatch): + """Test that SFT configs have freeze options set to False by default.""" + 
monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + assert cfg.model.freeze_language_model is False + assert cfg.model.freeze_vision_model is False + assert cfg.model.freeze_vision_projection is False + + +def test_qwen35_vl_peft_freeze_defaults(monkeypatch: pytest.MonkeyPatch): + """Test that PEFT configs have freeze options set to False by default.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_peft_config() + + assert cfg.model.freeze_language_model is False + assert cfg.model.freeze_vision_model is False + assert cfg.model.freeze_vision_projection is False + + +def test_qwen35_vl_precision_config(monkeypatch: pytest.MonkeyPatch): + """Test that precision config is correctly set.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.mixed_precision == "bf16_mixed" + + +def test_qwen35_vl_ddp_config(monkeypatch: pytest.MonkeyPatch): + """Test that DDP config is correctly set for VLMs.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.ddp.overlap_grad_reduce is False + assert cfg.ddp.overlap_param_gather is False + assert cfg.ddp.check_for_nan_in_grad is True + assert cfg.ddp.use_distributed_optimizer is True + + +def test_qwen35_vl_optimizer_precision_defaults(monkeypatch: pytest.MonkeyPatch): + """Test that optimizer precision settings are correctly configured.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.optimizer.use_precision_aware_optimizer is False + assert cfg.optimizer.main_grads_dtype == torch.float32 + assert cfg.optimizer.main_params_dtype 
== torch.float32 + assert cfg.optimizer.exp_avg_dtype == torch.float32 + assert cfg.optimizer.exp_avg_sq_dtype == torch.float32 + + +def test_qwen35_vl_training_config(monkeypatch: pytest.MonkeyPatch): + """Test that training configuration is correctly set.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.train.train_iters == 300000 + assert cfg.train.global_batch_size == 32 + assert cfg.train.micro_batch_size == 1 + assert cfg.train.manual_gc is True + assert cfg.train.manual_gc_interval == 100 + + +def test_qwen35_vl_validation_config(monkeypatch: pytest.MonkeyPatch): + """Test that validation configuration is correctly set.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.validation.eval_interval == 500 + assert cfg.validation.eval_iters == 32 + + +def test_qwen35_vl_sft_learning_rate(monkeypatch: pytest.MonkeyPatch): + """Test that SFT has lower learning rate than PEFT.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + sft_cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + peft_cfg = _qwen35_vl_module.qwen35_vl_800m_peft_config() + + assert sft_cfg.optimizer.lr < peft_cfg.optimizer.lr + + +def test_qwen35_vl_kernel_settings(monkeypatch: pytest.MonkeyPatch): + """Test that kernel settings are correctly configured.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.attention_backend == "auto" + assert cfg.model.cross_entropy_loss_fusion is True + assert cfg.model.cross_entropy_fusion_impl == "native" + + +def test_qwen35_vl_cuda_graph_settings(monkeypatch: pytest.MonkeyPatch): + """Test that CUDA graph settings are correctly configured.""" + 
monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.cuda_graph_impl == "none" + assert cfg.model.cuda_graph_scope == "full" + assert cfg.model.cuda_graph_warmup_steps == 3 + + +def test_qwen35_vl_transformer_impl(monkeypatch: pytest.MonkeyPatch): + """Test that transformer implementation is set correctly.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.transformer_impl == "transformer_engine" + + +def test_qwen35_vl_memory_saving_defaults(monkeypatch: pytest.MonkeyPatch): + """Test that memory saving settings are disabled by default.""" + monkeypatch.setattr(_qwen35_vl_module, "AutoBridge", _FakeAutoBridge) + + cfg = _qwen35_vl_module.qwen35_vl_800m_sft_config() + + _assert_basic_config(cfg) + assert cfg.model.recompute_granularity is None + assert cfg.model.recompute_modules is None + assert cfg.model.fine_grained_activation_offloading is False + assert cfg.model.offload_modules is None