diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 000000000..dc0cc7cbc
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,3 @@
+.github/CODEOWNERS @fzyzcjy @Ying1123
+.github/workflows/ @yushengsu-thu
+/miles/ @fzyzcjy @yueming-yuan
diff --git a/.github/workflows/conda-ci.yml b/.github/workflows/conda-ci.yml
new file mode 100644
index 000000000..332ced7f2
--- /dev/null
+++ b/.github/workflows/conda-ci.yml
@@ -0,0 +1,90 @@
+name: conda CI
+
+on:
+  pull_request:
+    branches: [main]
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-conda:
+    if: contains(github.event.pull_request.title, '[release]')
+    runs-on: self-hosted
+    container:
+      image: lmsysorg/sglang:v0.5.0rc0-cu126
+      options: --gpus all --ipc=host --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 -v /mnt/nvme0n1/models:/root/models -v /mnt/nvme0n1/datasets:/root/datasets
+
+    defaults:
+      run:
+        working-directory: ${{ github.workspace }}
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Build Conda environment
+        run: |
+          echo "📦 Installing miles..."
+          cd $GITHUB_WORKSPACE
+          echo "Current directory: $(pwd)"
+
+          mkdir -p /root/
+          BASE_DIR=/root bash build_conda.sh
+        shell: bash
+
+      - name: Download model and dataset
+        run: |
+          echo "🔗 Downloading model and dataset..."
+
+          # Create cache directories if they don't exist
+          mkdir -p /root/models /root/datasets
+
+          echo "Downloading Qwen3-30B-A3B..."
+          hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B
+          hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8
+
+          hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/datasets/dapo-math-17k
+
+          hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/datasets/aime-2024
+        shell: bash
+
+      - name: Convert checkpoint
+        run: |
+          echo "🔄 Converting model checkpoint..."
+          cd $GITHUB_WORKSPACE
+          echo "Current directory: $(pwd)"
+
+          source ~/.bashrc
+          micromamba activate miles
+          export CUDA_HOME="$CONDA_PREFIX"
+
+          source scripts/models/qwen3-30B-A3B.sh
+          PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 tools/convert_hf_to_torch_dist.py \
+            ${MODEL_ARGS[@]} \
+            --hf-checkpoint /root/models/Qwen3-30B-A3B \
+            --save /root/Qwen3-30B-A3B_torch_dist
+        shell: bash
+
+      - name: Run tests
+        run: |
+          echo "🧪 Running tests..."
+          cd $GITHUB_WORKSPACE
+          echo "Current directory: $(pwd)"
+
+          source ~/.bashrc
+          micromamba activate miles
+          export CUDA_HOME="$CONDA_PREFIX"
+
+          MILES_TEST_USE_DEEPEP=0 MILES_TEST_USE_FP8_ROLLOUT=0 python tests/test_qwen3_30B_A3B.py
+        shell: bash
+
+      - name: Cleanup
+        if: always()
+        run: |
+          echo "🧹 Cleaning up..."
+ pkill -9 ray || true + ray stop --force || true + pkill -9 python || true + shell: bash diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index e649da717..4b8b5dc82 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,6 +25,46 @@ concurrency: jobs: + fast: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 0, "test_file": "fast"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted @@ -33,18 +73,21 @@ jobs: options: > --gpus all --ipc=host - --shm-size=16g + --shm-size=32g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 - -v /data/miles_ci:/data/miles_ci - -v /data/miles_ci/models:/root/models - -v /data/miles_ci/datasets:/root/datasets + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}] + info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -52,19 +95,311 @@ jobs: GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + - name: Install shell: bash - 
run: cd $GITHUB_WORKSPACE && pip install -e . + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-fsdp: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-fsdp')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-megatron: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-megatron')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-precision: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-precision')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-ckpt: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-ckpt')) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + e2e-test-long: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-long')) runs-on: self-hosted @@ -73,18 +408,21 @@ jobs: options: > --gpus all --ipc=host - --shm-size=16g + --shm-size=32g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 - -v /data/miles_ci:/data/miles_ci - -v /data/miles_ci/models:/root/models - -v /data/miles_ci/datasets:/root/datasets + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -92,15 +430,106 @@ jobs: GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + - name: Install shell: bash - run: cd $GITHUB_WORKSPACE && pip install -e . + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages - name: Execute shell: bash run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + + e2e-test-image: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-image')) + runs-on: self-hosted + container: + image: radixark/miles-test:latest + options: > + --gpus all + --ipc=host + --shm-size=32g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 4, "test_file": "test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + + - name: Post-test cleanup + if: always() + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 06d6ed570..c052b8494 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,9 +1,52 @@ <% set jobs = { + 'fast': { + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'fast', 'num_gpus': 0}, + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', + 'tests': [ + {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + ], + }, + 'e2e-test-fsdp': { + 'label': 'run-ci-fsdp', + 'tests': [ + {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + ], + }, + 'e2e-test-megatron': { + 'label': 'run-ci-megatron', 'tests': [ {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1'}, + {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, + {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + ], + }, + 'e2e-test-precision': { + 'label': 'run-ci-precision', + 'tests': [ + {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + ], + }, + 'e2e-test-ckpt': { + 'label': 'run-ci-ckpt', + 'tests': [ + {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, ], }, 'e2e-test-long': { @@ -11,8 +54,29 @@ 'tests': [ {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, + ], + }, + 'e2e-test-image': { + 'label': 'run-ci-image', + 'image': 'radixark/miles-test:latest', + 'tests': [ + {'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4}, {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8}, {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8}, + {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8}, + {'test_file': 'test_mimo_7B_mtp_only_grad.py', 
'num_gpus': 8}, + {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_0.6B_megatron_fsdp_align.py', 'num_gpus': 4}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'test_qwen2.5_0.5B_gsm8k.py', 'num_gpus': 2}, + {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, ], }, } %> @@ -40,21 +104,24 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) runs-on: self-hosted container: - image: radixark/miles:latest + image: << config.image if config.image else 'radixark/miles:latest' >> options: > --gpus all --ipc=host - --shm-size=16g + --shm-size=32g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 - -v /data/miles_ci:/data/miles_ci - -v /data/miles_ci/models:/root/models - -v /data/miles_ci/datasets:/root/datasets + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + --privileged + --ulimit nofile=65535:65535 + -v /tmp:/tmp strategy: fail-fast: false matrix: @@ -66,16 +133,31 @@ jobs: GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + MILES_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} + MILES_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} + MILES_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Cleanup Ray processes + shell: bash + run: | + pkill -9 -f 'ray::' 2>/dev/null || true + pkill -9 -f raylet 2>/dev/null || true + pkill -9 -f gcs_server 2>/dev/null || true + pkill -9 -f 'ray-dashboard' 2>/dev/null || true + pkill -9 sglang 2>/dev/null || true + ray stop --force 2>/dev/null || true + rm -rf /tmp/ray/* 2>/dev/null || true + sleep 3 + - name: Install shell: bash - run: cd $GITHUB_WORKSPACE && pip install -e . + run: cd $GITHUB_WORKSPACE && pip install -e . 
--no-deps --break-system-packages - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/.github/workflows/release-docs.yaml b/.github/workflows/release-docs.yaml new file mode 100644 index 000000000..1da468e00 --- /dev/null +++ b/.github/workflows/release-docs.yaml @@ -0,0 +1,53 @@ +name: Release Documentation + +on: + push: + branches: + - main + paths: + - "docs/**" + - "examples/**" + - "version.txt" + workflow_dispatch: + +concurrency: + group: release-docs-${{ github.ref }} + cancel-in-progress: true + +jobs: + deploy: + runs-on: ubuntu-latest + if: github.repository == 'radixark/miles' + permissions: + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + + - name: Install dependencies + run: | + apt-get update && apt-get install -y pandoc parallel retry + pip install -r docs/requirements.txt + + - name: Build documentation + run: | + cd docs + bash ./build.sh en + bash ./build.sh zh + mv ./build/zh ./build/en/ + env: + LC_ALL: "en_US.UTF-8" + LC_CTYPE: "en_US.UTF-8" + + + - name: Deploy + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs/build/en \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..21ea17aba --- /dev/null +++ b/.gitmodules @@ -0,0 +1,8 @@ +[submodule "examples/swe-agent/nemo-gym"] + path = examples/experimental/swe-agent/nemo-gym + url = https://github.com/yueming-yuan/Gym + branch = miles-swe-agent +[submodule "examples/swe-agent/mini-swe-agent"] + path = examples/experimental/swe-agent/mini-swe-agent + url = https://github.com/yueming-yuan/nv-mini-swe-agent + branch = miles-swe-agent diff --git a/build_conda.sh b/build_conda.sh index 46fc12df6..b9787aef4 100644 --- a/build_conda.sh +++ b/build_conda.sh @@ -12,6 +12,8 @@ source ~/.bashrc micromamba create -n miles python=3.12 pip -c conda-forge -y micromamba activate miles export CUDA_HOME="$CONDA_PREFIX" +export SGLANG_COMMIT="24c91001cf99ba642be791e099d358f4dfe955f5" +export MEGATRON_COMMIT="3714d81d418c9f1bca4594fc35f9e8289f652862" export BASE_DIR=${BASE_DIR:-"/root"} cd $BASE_DIR @@ -27,7 +29,7 @@ pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https # install sglang git clone https://github.com/sgl-project/sglang.git cd sglang -git checkout 5e2cda6158e670e64b926a9985d65826c537ac82 +git checkout ${SGLANG_COMMIT} # Install the python packages pip install -e "python[all]" @@ -46,10 +48,6 @@ NVCC_APPEND_FLAGS="--threads 4" \ --no-build-isolation \ --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 -git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ - cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ - pip install -e . 
- pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation @@ -57,12 +55,9 @@ pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation # megatron cd $BASE_DIR git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ - cd Megatron-LM/ && git checkout core_v0.14.0 && \ + cd Megatron-LM/ && git checkout ${MEGATRON_COMMIT} && \ pip install -e . -# https://github.com/pytorch/pytorch/issues/168167 -pip install nvidia-cudnn-cu12==9.16.0.29 - # install miles and apply patches # if miles does not exist locally, clone it @@ -77,8 +72,11 @@ else pip install -e . fi +# https://github.com/pytorch/pytorch/issues/168167 +pip install nvidia-cudnn-cu12==9.16.0.29 + # apply patch cd $BASE_DIR/sglang -git apply $MILES_DIR/docker/patch/v0.5.6/sglang.patch +git apply $MILES_DIR/docker/patch/v0.5.7/sglang.patch cd $BASE_DIR/Megatron-LM -git apply $MILES_DIR/docker/patch/v0.5.6/megatron.patch \ No newline at end of file +git apply $MILES_DIR/docker/patch/v0.5.7/megatron.patch \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index cfffa422f..a48dc5987 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,10 +1,10 @@ -ARG SGLANG_IMAGE_TAG=nightly-dev-20251208-5e2cda61 +ARG SGLANG_IMAGE_TAG=nightly-dev-20260103-24c91001 FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang # ======================================== Arguments ============================================= ARG PATCH_VERSION=latest -ARG MEGATRON_COMMIT=core_v0.14.0 +ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862 ARG ENABLE_CUDA_13=0 @@ -35,6 +35,7 @@ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps RUN pip install flash-linear-attention==0.4.0 +RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ # TE does not have wheel on cuda 13 yet, thus need to install from source RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ @@ -71,25 +72,14 @@ RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ python3 -m pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps; \ fi -# AMEM -# we need to create a fake libcuda.so.1 to make the linker happy when building AMEM -ENV CUDA_DIR=/usr/local/cuda -ENV CUDA_STUBS=${CUDA_DIR}/lib64/stubs -RUN ln -s ${CUDA_STUBS}/libcuda.so ${CUDA_STUBS}/libcuda.so.1 && \ - echo "${CUDA_STUBS}" > /etc/ld.so.conf.d/z-cuda-stubs.conf && \ - ldconfig -RUN git clone https://github.com/inclusionAI/asystem-amem.git && \ - cd asystem-amem && git checkout 6483bb17c9a98b51c3a94b7048467d5b50fbad4b && \ - git submodule init && git submodule update && \ - MPI_HOME=/usr/lib/x86_64-linux-gnu/openmpi/ ./build.sh && \ - mv /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2 /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2.bak && \ - cp -r third_party/nccl/build/lib/* /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ - # https://github.com/pytorch/pytorch/issues/168167 RUN pip install nvidia-cudnn-cu12==9.16.0.29 +# reinstall numpy 1.x for megatron +RUN pip install "numpy<2" + RUN rm /root/.tmux.conf -RUN rm -rf /root/.cache/pip 
/root/asystem-amem /root/flash-attention +RUN rm -rf /root/.cache/pip /root/flash-attention # ====================================== Patches ============================================ diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index 6dc1353f0..dd32f32c5 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,252 +1,252 @@ -#### Use the base image for ROCm 7 / gfx950 (MI355) - -# The Docker image built with this Dockerfile: -# Base image: ROCm 7 with vllm pre-built for gfx950 -# Target GPU: MI355 (gfx950) - - -FROM rocm/sgl-dev:rocm7-vllm-20250904 - -SHELL ["/bin/bash", "-ceuxo", "pipefail"] - -ARG MAX_JOBS=128 -ENV MAX_JOBS=${MAX_JOBS} - -# Set environment variables for gfx950 -ENV GPU_ARCH=gfx950 -ENV PYTORCH_ROCM_ARCH=gfx950 -ENV GPU_ARCH_LIST=gfx950 -ENV AMDGPU_TARGET=gfx950 - - -########################################### -##############1. Install AITER############# -########################################### -WORKDIR /app - -RUN pip uninstall -y aiter || true -RUN rm -rf aiter -RUN git clone https://github.com/ROCm/aiter.git \ - && cd aiter \ - && git checkout v0.1.7.post2 \ - && git submodule update --init --recursive \ - && GPU_ARCHS=gfx950 python setup.py develop -########################################### -########################################### -########################################### - - -########################################### -####2. Install TransformerEngine for gfx950 -########################################### -WORKDIR /app - -RUN rm -rf TransformerEngine -RUN git clone https://github.com/ROCm/TransformerEngine.git \ - && cd TransformerEngine \ - && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ - && git submodule update --init --recursive - -ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx950 -ENV NVTE_USE_HIPBLASLT=1 -ENV NVTE_USE_ROCM=1 -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - -RUN cd TransformerEngine && pip install . -v -########################################### -########################################### -########################################### - - -######################################### -####3. Install Megatron-LM (NVIDIA version) -######################################### -WORKDIR /app - -RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall - -RUN pip uninstall -y megatron-core || true -RUN rm -rf Megatron-LM -RUN git clone https://github.com/NVIDIA/Megatron-LM \ - && cd Megatron-LM \ - && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ - && pip install -e . -######################################### -######################################### -######################################### - - -######################################## -############ 4. Install mbridge######### -######################################## -RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps -######################################## -######################################## -######################################## - - -######################################## -######5. Install Ray#################### -######################################## -RUN pip uninstall ray -y || true -RUN pip install "ray[data,train,tune,serve]==2.47.1" -######################################## -######################################## -######################################## - - -######################################### -###6. 
Install torch_memory_saver######### -######################################### -RUN pip install torch_memory_saver -######################################### -######################################### - - -####################################### -####7. Install Apex for ROCm########### -####################################### -WORKDIR /app - -RUN pip uninstall -y apex || true -RUN rm -rf apex -RUN git clone https://github.com/ROCm/apex.git \ - && cd apex \ - && python setup.py install -####################################### -####################################### -####################################### - - -######################################## -###8. Install slime agent framework deps -######################################## -RUN pip install pydra_config==0.0.15 -RUN pip install together -RUN pip install google-generativeai -RUN pip install tensorboard -######################################## -######################################## -######################################## - - -######################################## -###9. Set performance environment vars## -######################################## -ENV HIP_FORCE_DEV_KERNARG=1 -ENV HSA_NO_SCRATCH_RECLAIM=1 -ENV SGLANG_USE_AITER=1 -ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 -ENV SGLANG_MOE_PADDING=1 -ENV SGLANG_SET_CPU_AFFINITY=1 -ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 -ENV SGLANG_USE_ROCM700A=1 -ENV NCCL_MIN_NCHANNELS=112 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -######################################## -######################################## -######################################## - - -########################################### -##############Install SGLang############### -########################################### -WORKDIR /app - -# Install prerequisites -RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 - -# Clone SGLang -RUN pip uninstall -y sgl_kernel sglang || true -RUN rm -rf sglang -RUN git clone https://github.com/sgl-project/sglang.git \ - && cd sglang \ - && git checkout v0.5.6 - -# Build sgl-kernel for gfx950 -RUN cd sglang/sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && AMDGPU_TARGET=gfx950 python setup_rocm.py install - -# Install SGLang -RUN cd sglang \ - && rm -rf python/pyproject.toml \ - && mv python/pyproject_other.toml python/pyproject.toml \ - && pip install -e "python[all_hip]" - -# Test SGLang installation -RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" - -RUN python -m pip cache purge -########################################### -########################################### -########################################### - - -########################################### -#### APPLY PATCHES (gfx950/MI355) ######### -########################################### - -# Copy patches from slime repo -COPY amd_patch/latest /app/patch - -# Apply Megatron patches -RUN cd /app/Megatron-LM \ - && git apply /app/patch/amd_megatron_fused_kernels_init.patch \ - && git apply /app/patch/megatron.patch --3way \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "Patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi \ - && pip install -e . -v - -# Apply SGLang patch -RUN cd /app/sglang \ - && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "Patch failed to apply cleanly. 
Please resolve conflicts." && \ - exit 1; \ - fi - -# Copy MOE configs for gfx950/MI355 -RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ - /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' 2>/dev/null | while read f; do \ - cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ - cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ -done - -########################################### -########################################### -########################################### - - -######################################## -#### Install additional packages######## -######################################## -RUN pip install sglang-router --force-reinstall -######################################## -######################################## -######################################## - - -######################################## -# Fix click/ray incompatibility with Python 3.10 -######################################## -RUN pip install click==8.2.1 -######################################## -######################################## -######################################## - - -WORKDIR /app - -CMD ["/usr/bin/bash"] - +#### Use the base image for ROCm 7 / gfx950 (MI355) + +# The Docker image built with this Dockerfile: +# Base image: ROCm 7 with vllm pre-built for gfx950 +# Target GPU: MI355 (gfx950) + + +FROM rocm/sgl-dev:rocm7-vllm-20250904 + +SHELL ["/bin/bash", "-ceuxo", "pipefail"] + +ARG MAX_JOBS=128 +ENV MAX_JOBS=${MAX_JOBS} + +# Set environment variables for gfx950 +ENV GPU_ARCH=gfx950 +ENV PYTORCH_ROCM_ARCH=gfx950 +ENV GPU_ARCH_LIST=gfx950 +ENV AMDGPU_TARGET=gfx950 + + +########################################### +##############1. Install AITER############# +########################################### +WORKDIR /app + +RUN pip uninstall -y aiter || true +RUN rm -rf aiter +RUN git clone https://github.com/ROCm/aiter.git \ + && cd aiter \ + && git checkout v0.1.7.post2 \ + && git submodule update --init --recursive \ + && GPU_ARCHS=gfx950 python setup.py develop +########################################### +########################################### +########################################### + + +########################################### +####2. Install TransformerEngine for gfx950 +########################################### +WORKDIR /app + +RUN rm -rf TransformerEngine +RUN git clone https://github.com/ROCm/TransformerEngine.git \ + && cd TransformerEngine \ + && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ + && git submodule update --init --recursive + +ENV NVTE_FRAMEWORK=pytorch +ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_USE_HIPBLASLT=1 +ENV NVTE_USE_ROCM=1 +ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + +RUN cd TransformerEngine && pip install . -v +########################################### +########################################### +########################################### + + +######################################### +####3. Install Megatron-LM (NVIDIA version) +######################################### +WORKDIR /app + +RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall + +RUN pip uninstall -y megatron-core || true +RUN rm -rf Megatron-LM +RUN git clone https://github.com/NVIDIA/Megatron-LM \ + && cd Megatron-LM \ + && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ + && pip install -e . 
+######################################### +######################################### +######################################### + + +######################################## +############ 4. Install mbridge######### +######################################## +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps +######################################## +######################################## +######################################## + + +######################################## +######5. Install Ray#################### +######################################## +RUN pip uninstall ray -y || true +RUN pip install "ray[data,train,tune,serve]==2.47.1" +######################################## +######################################## +######################################## + + +######################################### +###6. Install torch_memory_saver######### +######################################### +RUN pip install torch_memory_saver +######################################### +######################################### + + +####################################### +####7. Install Apex for ROCm########### +####################################### +WORKDIR /app + +RUN pip uninstall -y apex || true +RUN rm -rf apex +RUN git clone https://github.com/ROCm/apex.git \ + && cd apex \ + && python setup.py install +####################################### +####################################### +####################################### + + +######################################## +###8. Install miles agent framework deps +######################################## +RUN pip install pydra_config==0.0.15 +RUN pip install together +RUN pip install google-generativeai +RUN pip install tensorboard +######################################## +######################################## +######################################## + + +######################################## +###9. 
Set performance environment vars## +######################################## +ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 +ENV SGLANG_USE_AITER=1 +ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 +ENV SGLANG_MOE_PADDING=1 +ENV SGLANG_SET_CPU_AFFINITY=1 +ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 +ENV SGLANG_USE_ROCM700A=1 +ENV NCCL_MIN_NCHANNELS=112 +ENV VLLM_FP8_PADDING=1 +ENV VLLM_FP8_ACT_PADDING=1 +ENV VLLM_FP8_WEIGHT_PADDING=1 +ENV VLLM_FP8_REDUCE_CONV=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 +######################################## +######################################## +######################################## + + +########################################### +##############Install SGLang############### +########################################### +WORKDIR /app + +# Install prerequisites +RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 + +# Clone SGLang +RUN pip uninstall -y sgl_kernel sglang || true +RUN rm -rf sglang +RUN git clone https://github.com/sgl-project/sglang.git \ + && cd sglang \ + && git checkout v0.5.6 + +# Build sgl-kernel for gfx950 +RUN cd sglang/sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && AMDGPU_TARGET=gfx950 python setup_rocm.py install + +# Install SGLang +RUN cd sglang \ + && rm -rf python/pyproject.toml \ + && mv python/pyproject_other.toml python/pyproject.toml \ + && pip install -e "python[all_hip]" + +# Test SGLang installation +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN python -m pip cache purge +########################################### +########################################### +########################################### + + +########################################### +#### APPLY PATCHES (gfx950/MI355) ######### +########################################### + +# Copy patches from miles repo +COPY amd_patch/latest /app/patch + +# Apply Megatron patches +RUN cd /app/Megatron-LM \ + && git apply /app/patch/amd_megatron_fused_kernels_init.patch \ + && git apply /app/patch/megatron.patch --3way \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi \ + && pip install -e . -v + +# Apply SGLang patch +RUN cd /app/sglang \ + && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." 
&& \ + exit 1; \ + fi + +# Copy MOE configs for gfx950/MI355 +RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ + /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' 2>/dev/null | while read f; do \ + cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ + cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ +done + +########################################### +########################################### +########################################### + + +######################################## +#### Install additional packages######## +######################################## +RUN pip install sglang-router --force-reinstall +######################################## +######################################## +######################################## + + +######################################## +# Fix click/ray incompatibility with Python 3.10 +######################################## +RUN pip install click==8.2.1 +######################################## +######################################## +######################################## + + +WORKDIR /app + +CMD ["/usr/bin/bash"] + diff --git a/docker/Dockerfile_20250810_9a48ba0.rocm b/docker/Dockerfile_20250810_9a48ba0.rocm index 073fb3891..db3d25a21 100644 --- a/docker/Dockerfile_20250810_9a48ba0.rocm +++ b/docker/Dockerfile_20250810_9a48ba0.rocm @@ -82,7 +82,7 @@ RUN pip install setuptools==75.8.0 ########################################### -############build sgalng################### +############build sglang################### ########################################### # Set environment variables ENV BASE_DIR=/workspace diff --git a/docker/Dockerfile_20250810_c22f55b.rocm b/docker/Dockerfile_20250810_c22f55b.rocm index 468a17c37..de6d93422 100644 --- a/docker/Dockerfile_20250810_c22f55b.rocm +++ b/docker/Dockerfile_20250810_c22f55b.rocm @@ -92,7 +92,7 @@ RUN pip install setuptools==75.8.0 ########################################### -############build sgalng################### +############build sglang################### ########################################### # Set environment variables ENV BASE_DIR=/workspace diff --git a/docker/README.md b/docker/README.md index 156169c72..29929d282 100644 --- a/docker/README.md +++ b/docker/README.md @@ -5,10 +5,12 @@ We will publish 2 kinds of docker images: 2. latest version, which aligns to `lmsysorg/sglang:latest`. 
current stable version is: -- sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang v0.5.7 nightly-dev-20260103-24c91001 (24c91001cf99ba642be791e099d358f4dfe955f5), megatron dev 3714d81d418c9f1bca4594fc35f9e8289f652862 history versions: +- sglang v0.5.6 nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) - sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang v0.5.0rc0-cu126 (8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f), megatron 48406695c4efcf1026a7ed70bb390793918dd97b The command to build: diff --git a/docker/deepseekv32/Dockerfile b/docker/deepseekv32/Dockerfile new file mode 100644 index 000000000..38e9eadfa --- /dev/null +++ b/docker/deepseekv32/Dockerfile @@ -0,0 +1,145 @@ +ARG SGLANG_IMAGE_TAG=v0.5.6.post2 +FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang + +# ======================================== Arguments ============================================= + +ARG PATCH_VERSION=latest +ARG MEGATRON_COMMIT=436065a86b749ca3b50eebca68f55c9e690a9f63 + +ARG ENABLE_CUDA_13=0 + +ARG ENABLE_SGLANG_PATCH=0 + +# ======================================== Setup ============================================= + +WORKDIR /root/ + +# ======================================== Apt dependencies ============================================= + +RUN apt update +RUN apt install -y nvtop rsync dnsutils + +# ====================================== Python dependencies ============================================ + +# The compilation is slow, thus should be put at top +# TransformerEngines does not support too high FA2 +RUN MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 --no-build-isolation + +# The compilation is slow, thus should be put at top +RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention/ && git checkout fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89 && git submodule update --init && cd hopper/ && \ + MAX_JOBS=96 python setup.py install && \ + export python_path=`python -c "import site; print(site.getsitepackages()[0])"` && \ + mkdir -p $python_path/flash_attn_3 && \ + cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py && \ + rm -rf flash-attention/ + +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +RUN pip install flash-linear-attention==0.4.0 + +RUN git clone https://github.com/Dao-AILab/fast-hadamard-transform.git fast-hadamard-transform && \ + cd fast-hadamard-transform && \ + pip install -v . 
--no-build-isolation && \ + cd /root && \ + rm -rf fast-hadamard-transform + +# TE does not have wheel on cuda 13 yet, thus need to install from source +RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ + pip install nvidia-mathdx==25.6.0 && \ + pip install pybind11 && \ + pip -v install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.10; \ + else \ + pip -v install --no-build-isolation "transformer_engine[pytorch]==2.10.0"; \ + fi + +RUN NVCC_APPEND_FLAGS="--threads 4" \ + pip -v install --disable-pip-version-check --no-cache-dir \ + --no-build-isolation \ + --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 + +RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ + cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ + pip install -e . + +RUN git clone https://github.com/huggingface/transformers.git && \ + cd transformers && git checkout 8cb5963cc22174954e7dca2c0a3320b7dc2f4edc && \ + pip install -e . + +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall +RUN pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +RUN pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation + +# This patch from masahi will be included in later Triton releases +RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ + (cd /root && git clone -b feat/v350_plus_8045 https://github.com/fzyzcjy/triton.git && cd triton && pip install -r python/requirements.txt && pip install --verbose -e .); \ + fi + +COPY requirements.txt /tmp/requirements.txt +RUN pip install -r /tmp/requirements.txt + +# Temporarily install another sgl-kernel version for GB300 without rebuilding the whole image +RUN if [ "$ENABLE_CUDA_13" = "1" ]; then \ + SGL_KERNEL_VERSION=0.3.17.post2 && \ + python3 -m pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps; \ + fi + +# AMEM +# we need to create a fake libcuda.so.1 to make the linker happy when building AMEM +# ENV CUDA_DIR=/usr/local/cuda +# ENV CUDA_STUBS=${CUDA_DIR}/lib64/stubs +# RUN ln -s ${CUDA_STUBS}/libcuda.so ${CUDA_STUBS}/libcuda.so.1 && \ +# echo "${CUDA_STUBS}" > /etc/ld.so.conf.d/z-cuda-stubs.conf && \ +# ldconfig +# RUN git clone https://github.com/inclusionAI/asystem-amem.git && \ +# cd asystem-amem && git checkout 6483bb17c9a98b51c3a94b7048467d5b50fbad4b && \ +# git submodule init && git submodule update && \ +# MPI_HOME=/usr/lib/x86_64-linux-gnu/openmpi/ ./build.sh && \ +# mv /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2 /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2.bak && \ +# cp -r third_party/nccl/build/lib/* /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ + +RUN [ ! -f /root/.tmux.conf ] || rm /root/.tmux.conf + +# ====================================== Patches ============================================ + +COPY docker/deepseekv32/megatron.patch /root/Megatron-LM/ +RUN cd Megatron-LM && \ + git update-index --refresh && \ + git apply megatron.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." 
&& \ + exit 1; \ + fi && \ + rm megatron.patch + +COPY docker/deepseekv32/transformers.patch /root/transformers/ +RUN cd transformers && \ + git update-index --refresh && \ + git apply transformers.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm transformers.patch + +# TODO temporarily skip patching for GB200/GB300 (and require users to bring their own sglang version). should add back later. +COPY docker/patch/${PATCH_VERSION}/sglang.patch /sgl-workspace/sglang/ +RUN if [ "$ENABLE_SGLANG_PATCH" = "1" ]; then \ + cd /sgl-workspace/sglang && \ + git update-index --refresh && \ + git apply sglang.patch && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm sglang.patch; \ +fi + +# ====================================== Install main package ============================================ + +# TODO may improve +ARG MILES_COMMIT=main +RUN git clone https://github.com/radixark/miles.git /root/miles && \ + cd /root/miles && \ + git checkout ${MILES_COMMIT} && \ + pip install -e . --no-deps diff --git a/docker/deepseekv32/README.md b/docker/deepseekv32/README.md new file mode 100644 index 000000000..1ca43f1c2 --- /dev/null +++ b/docker/deepseekv32/README.md @@ -0,0 +1,41 @@ +## Usage + +### Docker +```bash +docker pull yueming11/miles:dsv32-dev + +docker run --gpus all --ipc=host --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 --name miles_dsv32 yueming11/miles:dsv32-dev /bin/zsh + +git clone https://github.com/radixark/miles.git +git checkout dsv32 +cd dsv32 +pip install -e . + +# if shows Megatron does not support numpy 2.x +pip install numpy==1.26.4 +``` + +### Quick test with 5 layer model +#### model download + +``` +hf download Pinaster/DeepSeek-V3.2-5layer /root/models/DeepSeek-V3.2-5layer +``` + +#### Prepare model for training +Note: need to change the paths, for all commands below see `scripts/run_deepseek_v32.py` for details + +Step 1. download dataset & convert fp8 hf checkpoint to bf16 with one node +``` +python scripts/run_deepseek_v32.py prepare-single --model-name DeepSeek-V3.2-5layer --megatron-model-type deepseek-v32-5layer +``` + +Step 2. convert hf checkpoint to megatron checkpoint with multiple nodes +``` +python scripts/run_deepseek_v32.py prepare-spmd --model-name DeepSeek-V3.2-5layer --megatron-model-type deepseek-v32-5layer +``` + +#### Launch training +``` +python scripts/run_deepseek_v32.py train --model-name DeepSeek-V3.2-5layer --megatron-model-type deepseek-v32-5layer +``` \ No newline at end of file diff --git a/docker/deepseekv32/megatron.patch b/docker/deepseekv32/megatron.patch new file mode 100644 index 000000000..f7204197c --- /dev/null +++ b/docker/deepseekv32/megatron.patch @@ -0,0 +1,1559 @@ +diff --git a/megatron/core/transformer/dot_product_attention_context_parallel.py b/megatron/core/transformer/dot_product_attention_context_parallel.py +index 89659a1d7..c69859a04 100644 +--- a/megatron/core/transformer/dot_product_attention_context_parallel.py ++++ b/megatron/core/transformer/dot_product_attention_context_parallel.py +@@ -3,107 +3,12 @@ + # Some of this code was adopted from https://github.com/zhuzilin/ring-flash-attention/ + # This source code is licensed under the MIT license found in the + # LICENSE file in the root directory of this source tree. 
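The megatron.patch that follows rewrites the context-parallel (CP) attention path around TileLang sparse-MLA kernels, and it repeatedly assumes a zigzag ("zz") layout for the sequence shards (`zz_indices`, `zz_masks`, `get_causal_mask`). Below is a minimal standalone sketch of that layout for `cp_size=2`; the chunk assignment mirrors the in-patch comment "rank 0: q [chunk_0, chunk_3], k [chunk_0, chunk_3, chunk_1, chunk_2]", and the concrete numbers are illustrative only, not part of the patch.

```python
# Zigzag ("zz") context-parallel layout assumed by the patch below, sketched for
# cp_size=2 and a global sequence of 8 positions split into 2*cp_size chunks.
cp_size = 2
seq_len_global = 8
chunk = seq_len_global // (2 * cp_size)              # positions per chunk
chunks = [list(range(i * chunk, (i + 1) * chunk)) for i in range(2 * cp_size)]

# Rank r owns chunk r and chunk (2*cp_size - 1 - r), which balances causal work.
owned = {r: chunks[r] + chunks[2 * cp_size - 1 - r] for r in range(cp_size)}
assert owned == {0: [0, 1, 6, 7], 1: [2, 3, 4, 5]}

# An all-gather over the CP group concatenates the local KV shards rank by rank,
# so the gathered KV columns arrive in "zz" order [chunk_0, chunk_3, chunk_1, chunk_2];
# masks and top-k indices must be permuted to match, which is what get_causal_mask()
# and the zz_* reshapes in the patch take care of.
zz_order = [pos for r in range(cp_size) for pos in owned[r]]
assert zz_order == [0, 1, 6, 7, 2, 3, 4, 5]
```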
++# Kernel is adpoted from tilelang/examples/deepseek_v32 + + import torch ++import torch.distributed as dist + from torch.nn import functional as F +- +-try: +- import einops +- +- HAVE_EINOPS = True +-except ImportError: +- HAVE_EINOPS = False +- +- +-@torch.no_grad +-def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout): +- """Forward pass for eager attention""" +- +- # Rearrange query, key, value to (b, h, s, d) +- b, sq, h, d = q.shape +- sk = k.shape[1] +- _q = einops.rearrange(q, 'b s h d -> b h s d') +- _k = einops.rearrange(k, 'b s h d -> b h d s') +- _v = einops.rearrange(v, 'b s h d -> b h s d') +- +- # Compute attention weights +- attn_w = torch.matmul(_q, _k) * scale +- attn_w = attn_w + attn_bias +- +- # Add sinks to attention weights +- if sinks is None: +- logits = attn_w +- else: +- _sinks = sinks.reshape(1, h, 1, 1).expand(b, -1, sq, 1) +- logits = torch.cat([attn_w, _sinks], dim=-1) +- +- # Compute attention scores +- probs = F.softmax(logits, dim=-1, dtype=logits.dtype) +- if sinks is None: +- attn_w = probs +- else: +- attn_w = probs[..., :-1] # Drop the sink +- +- # Compute attention output +- attn_output = torch.matmul(attn_w, _v) +- attn_output = einops.rearrange(attn_output, 'b h s d -> b s h d') +- attn_output = attn_output.contiguous() +- +- return attn_output, probs +- +- +-@torch.no_grad +-def eager_attn_bwd(q, k, v, attn_bias, sinks, scale, dropout, attn_output, probs, grad_output): +- """Backward pass for eager attention""" +- +- # Rearrange query, key, value to (b, h, s, d) +- b, sq, h, d = q.shape +- sk = k.shape[1] +- _q_T = einops.rearrange(q, 'b s h d -> b h d s') +- _k_T = einops.rearrange(k, 'b s h d -> b h s d') +- _v_T = einops.rearrange(v, ' b s h d -> b h d s') +- +- # Backward pass for score @ value +- if sinks is None: +- attn_w = probs +- else: +- attn_w = probs[..., :-1] # Drop the sink +- grad_output = einops.rearrange(grad_output, 'b s h d -> b h s d') +- attn_w_T = einops.rearrange(attn_w, ' b h sq sk -> b h sk sq') +- grad__v = torch.matmul(attn_w_T, grad_output) +- grad_attn_w = torch.matmul(grad_output, _v_T) +- +- # Backward pass for softmax +- if sinks is None: +- grad_probs = grad_attn_w +- else: +- dummy = torch.zeros((b, h, sq, 1), device=q.device, dtype=q.dtype) +- grad_probs = torch.cat([grad_attn_w, dummy], dim=3) +- del grad_attn_w +- grad_logits = torch._softmax_backward_data( +- grad_probs, probs, -1, probs.dtype +- ) # [b, h, sq, sk+1] +- +- # Backward pass for adding sinks +- if sinks is None: +- grad_sinks = None +- grad_attn_w = grad_logits +- else: +- grad__sinks = grad_logits[:, :, :, -1] # [b, h, sq] +- grad_sinks = einops.rearrange(grad__sinks, 'b h s -> h (b s)').sum(-1) +- grad_attn_w = grad_logits[:, :, :, :-1].contiguous() # [b, h, sq, sk] +- +- # Backward pass for q @ K^T +- grad_attn_w *= scale +- grad__q = torch.matmul(grad_attn_w, _k_T) +- grad__k = torch.matmul(_q_T, grad_attn_w) +- +- # Rearrange grads to (b, s, h, d) +- grad_v = einops.rearrange(grad__v, 'b h s d -> b s h d') +- grad_k = einops.rearrange(grad__k, 'b h d s -> b s h d') +- grad_q = einops.rearrange(grad__q, 'b h s d -> b s h d') +- return grad_q, grad_k, grad_v, grad_sinks +- ++from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface + + class AllGatherComm: + """All gather communication with async operations""" +@@ -131,212 +36,145 @@ class AllGatherComm: + handle.wait() + self.handles = [] + +- +-def to_zz_mask_attn_bias(attention_mask, cp_size, nheads, nheads_k, heads_k_stride, device, dtype): +- '''Convert the attention 
mask to the attention bias''' +- +- if cp_size == 1: +- zz_mask = attention_mask +- else: +- chunked = attention_mask.chunk(dim=3, chunks=cp_size * 2) +- zz_mask = [_x for _p in zip(chunked[:cp_size], reversed(chunked[cp_size:])) for _x in _p] +- zz_mask = torch.cat(zz_mask, dim=3) +- attn_bias = torch.zeros(zz_mask.shape, device=device, dtype=dtype) +- attn_bias.masked_fill_(zz_mask, float('-inf')) +- attn_bias = attn_bias.expand(-1, heads_k_stride * (nheads // nheads_k), -1, -1) +- return attn_bias +- +- + class AttentionFuncionWithContextParallel(torch.autograd.Function): + """Native attention function with context parallelism.""" + ++ # q: [seq_len_shard, batch, nheads, dim] ++ # k: [seq_len_kv_shard, batch, 1, dim] ++ # v: [seq_len_kv_shard, batch, 1, dim_v] ++ # indices: [batch, 1, seq_len, topk] ++ # masks: [batch, 1, seq_len, seq_len_kv] + @staticmethod +- def forward(ctx, q, k, v, attention_mask, attention_dropout, softmax_scale, pg): ++ def forward(ctx, q, k, dim_v, indices, masks, attention_dropout, softmax_scale, pg): + '''Forward pass for the native attention function with context parallelism''' + +- # Assert einops exists +- if not HAVE_EINOPS: +- raise ImportError("einops is required by the attention CP but cannot be imported.") +- +- # Initialize communication group and constants + cp_size = 1 + if pg is not None: + cp_size = torch.distributed.get_world_size(pg) + comm = AllGatherComm(group=pg) +- nheads = q.shape[2] +- nheads_k = k.shape[2] +- heads_k_stride = 1 +- assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0 +- outs = [] +- probs = [] +- +- # Initialize KV buffers +- kv_buffer = torch.empty( +- (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]), ++ ++ k_buffer = torch.empty( ++ (k.shape[0] * cp_size, k.shape[1], 1, k.shape[3]), + dtype=k.dtype, + device=k.device, + ) +- kv_buffer_copy = torch.empty_like(kv_buffer) +- +- # All-gather first chunk of KV buffers +- k_0 = k[:, :, :heads_k_stride].contiguous() +- v_0 = v[:, :, :heads_k_stride].contiguous() +- comm.all_gather(kv_buffer_copy[0], k_0) +- comm.all_gather(kv_buffer_copy[1], v_0) +- +- # Prepare attention bias +- attn_bias = to_zz_mask_attn_bias( +- attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype +- ) +- +- # Iterate over heads +- for i in range(0, nheads_k, heads_k_stride): +- # Wait for previous all-gather to complete +- comm.wait() +- kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer +- # All-gather the next portion of KV buffers if not the last iteration +- if i < nheads_k - heads_k_stride: +- kvsl = i + heads_k_stride +- kvsr = kvsl + heads_k_stride +- send_k = k[:, :, kvsl:kvsr].contiguous() +- send_v = v[:, :, kvsl:kvsr].contiguous() +- comm.all_gather(kv_buffer_copy[0], send_k) +- comm.all_gather(kv_buffer_copy[1], send_v) +- +- # Prepare query, key, value for attention +- q_i = q[:, :, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] +- k_i = kv_buffer[0] +- v_i = kv_buffer[1] +- +- # Rearrange query, key, value to (b, s, h, d) +- q_i = einops.rearrange(q_i, 's b h d -> b s h d') +- k_i = einops.rearrange(k_i, 's b h d -> b s h d') +- v_i = einops.rearrange(v_i, 's b h d -> b s h d') +- +- # Forward pass +- out_i, probs_i = eager_attn_fwd( +- q_i, k_i, v_i, attn_bias, None, softmax_scale, attention_dropout +- ) +- outs.append(out_i) +- probs.append(probs_i) +- +- # Concatenate outputs and rearrange to (s, b, h, d) +- out = torch.cat(outs, dim=2) +- out = einops.rearrange(out, 'b s h d -> s b h d') +- +- # Save contexts for 
backward pass +- ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs) ++ comm.all_gather(k_buffer, k) ++ comm.wait() ++ ++ zz_indices = indices.transpose(1, 2) ++ zz_masks = masks.transpose(1, 2) ++ ++ q_i = q ++ k_i = k_buffer ++ ++ s_, b_, h_, d_ = q_i.shape ++ q_i = q_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ s_, b_, h_, d_ = k_i.shape ++ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ zz_indices_i = zz_indices ++ b_, s_, g_, topk_ = zz_indices_i.shape ++ zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_) ++ zz_masks_i = zz_masks ++ b_, s_, g_, skv_ = zz_masks_i.shape ++ zz_masks_i = zz_masks_i.flatten().view(b_, s_, g_, skv_) ++ ++ out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), k_i, zz_indices_i, zz_masks_i, dim_v, sm_scale = softmax_scale) ++ ++ # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim] ++ b_, s_, h_, d_ = out_i.shape ++ out_i = out_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() ++ ++ # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times ++ # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times ++ ctx.save_for_backward(q, k, indices, masks, out_i, lse_i) + ctx.dropout = attention_dropout +- ctx.scale = softmax_scale +- ctx.heads_k_stride = heads_k_stride # TODO make it configurable ++ ctx.softmax_scale = softmax_scale ++ ctx.dim_v = dim_v + ctx.pg = pg + +- return out ++ return out_i + + @staticmethod + def backward(ctx, dout): + '''Backward pass for the native attention function with context parallelism''' + +- # Initialize or resume constants and communication group +- q, k, v, attention_mask, *rest = ctx.saved_tensors +- nheads = q.shape[2] +- nheads_k = k.shape[2] +- heads_k_stride = ctx.heads_k_stride +- assert nheads_k % heads_k_stride == 0 +- outs = rest[: nheads_k // heads_k_stride] +- probs = rest[nheads_k // heads_k_stride :] ++ q, k, indices, masks, out, lse = ctx.saved_tensors ++ s, b, heads, dim = q.shape ++ dim_v = ctx.dim_v ++ softmax_scale = ctx.softmax_scale ++ + pg = ctx.pg + cp_size = 1 + if pg is not None: + cp_size = torch.distributed.get_world_size(pg) + comm = AllGatherComm(group=pg) + +- # Initialize KV buffers +- kv_buffer = torch.empty( +- (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]), ++ k_buffer = torch.empty( ++ (k.shape[0] * cp_size, k.shape[1], 1, k.shape[3]), + dtype=k.dtype, + device=k.device, + ) +- kv_buffer_copy = torch.empty_like(kv_buffer) +- +- # All-gather first chunk of KV buffers +- dq = [] +- dk = [] +- dv = [] +- k_0 = k[:, :, :heads_k_stride].contiguous() +- v_0 = v[:, :, :heads_k_stride].contiguous() +- comm.all_gather(kv_buffer_copy[0], k_0) +- comm.all_gather(kv_buffer_copy[1], v_0) +- +- # Prepare attention bias +- attn_bias = to_zz_mask_attn_bias( +- attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype +- ) + +- # Iterate over heads +- for i in range(0, nheads_k, heads_k_stride): +- # Slice query and output for this iteration +- q_slice = slice(i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k) +- q_i = q[:, :, q_slice] +- dout_i = dout[:, :, q_slice] +- +- # Wait for previous all-gather to complete +- comm.wait() +- kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer +- +- # All-gather the next portion of KV buffers if not the last iteration +- if i < nheads_k - heads_k_stride: +- kvsl = i + heads_k_stride +- kvsr = kvsl + 
heads_k_stride +- send_k = k[:, :, kvsl:kvsr].contiguous() +- send_v = v[:, :, kvsl:kvsr].contiguous() +- comm.all_gather(kv_buffer_copy[0], send_k) +- comm.all_gather(kv_buffer_copy[1], send_v) +- +- # Prepare key, value for attention +- k_i = kv_buffer[0] +- v_i = kv_buffer[1] +- +- # Rearrange query, key, value to (b, s, h, d) +- q_i = einops.rearrange(q_i, 's b h d -> b s h d') +- k_i = einops.rearrange(k_i, 's b h d -> b s h d') +- v_i = einops.rearrange(v_i, 's b h d -> b s h d') +- dout_i = einops.rearrange(dout_i, 's b h d -> b s h d') +- +- # Backward pass +- dq_i, _dk_i, _dv_i, _ = eager_attn_bwd( +- q_i, k_i, v_i, attn_bias, None, ctx.scale, ctx.dropout, outs[i], probs[i], dout_i +- ) ++ comm.all_gather(k_buffer, k) ++ comm.wait() ++ ++ zz_indices = indices.transpose(1, 2) ++ zz_masks = masks.transpose(1, 2) ++ ++ k_i = k_buffer ++ ++ dq_list = [] ++ dk_list = [] ++ ++ s_, b_, h_, d_ = q.shape ++ q = q.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ s_, b_, h_, d_ = k_i.shape ++ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ s_, b_, h_, d_ = dout.shape ++ dout = dout.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ s_, b_, h_, d_ = out.shape ++ out = out.transpose(0, 1).flatten().view(b_, s_, h_, d_) ++ b_, s_, h_ = lse.shape ++ lse = lse.flatten().view(b_, s_, h_) ++ zz_indices_i = zz_indices ++ b_, s_, g_, topk_ = zz_indices_i.shape ++ zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_) ++ zz_masks_i = zz_masks ++ b_, s_, g_, skv_ = zz_masks_i.shape ++ zz_masks_i = zz_masks_i.flatten().view(b_, s_, g_, skv_) ++ ++ heads_kv_stride = 16 ++ for i in range(0, heads, heads_kv_stride): ++ q_slice = slice(i, min(i + heads_kv_stride, heads)) ++ q_i = q[:, :, q_slice, :].contiguous() ++ dout_i = dout[:, :, q_slice, :].contiguous() ++ out_i = out[:, :, q_slice, :].contiguous() ++ lse_i = lse[:, :, q_slice].contiguous() ++ ++ # TODO: needs casual = True, may not be compatible with zz ++ dq_i, _dk_i = sparse_mla_bwd(q_i, k_i, out_i, dout_i, zz_indices_i, zz_masks_i, lse_i, dim_v, sm_scale = softmax_scale) ++ ++ b_, s_, h_, d_ = dq_i.shape ++ dq_i = dq_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() ++ b_, s_, h_, d_ = _dk_i.shape ++ _dk_i = _dk_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() + +- # Rearrange gradients to (s, b, h, d) +- dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d') +- _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d') +- _dv_i = einops.rearrange(_dv_i, 'b s h d -> s b h d') + if pg is None: + dk_i = _dk_i +- dv_i = _dv_i + else: +- # Reduce-scatter gradients if CP > 1 + dk_i = torch.zeros( + (k_i.shape[1] // cp_size, k_i.shape[0], k_i.shape[2], k_i.shape[3]), + device=k_i.device, +- dtype=k_i.dtype, +- ) +- dv_i = torch.zeros( +- (v_i.shape[1] // cp_size, v_i.shape[0], v_i.shape[2], v_i.shape[3]), +- device=v_i.device, +- dtype=v_i.dtype, ++ dtype=torch.float32, + ) + torch.distributed.reduce_scatter_tensor(dk_i, _dk_i, group=pg) +- torch.distributed.reduce_scatter_tensor(dv_i, _dv_i, group=pg) + +- # Collect gradients +- dq.append(dq_i) +- dk.append(dk_i) +- dv.append(dv_i) ++ dq_list.append(dq_i) ++ dk_list.append(dk_i) + + # Concatenate gradients and return +- dq = torch.cat(dq, dim=2) +- dk = torch.cat(dk, dim=2) +- dv = torch.cat(dv, dim=2) +- return dq, dk, dv, None, None, None, None ++ dq = torch.cat(dq_list, dim=2) ++ dk_ = torch.cat(dk_list, dim=2) ++ dk = torch.sum(dk_, dim=2, keepdim=True).to(torch.bfloat16) ++ ++ return dq, dk, None, None, None, None, None, None +\ No newline at end of file +diff 
--git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py +index 353b31e9b..221e93500 100644 +--- a/megatron/core/transformer/experimental_attention_variant/dsa.py ++++ b/megatron/core/transformer/experimental_attention_variant/dsa.py +@@ -6,6 +6,7 @@ from dataclasses import dataclass + from typing import Optional, Tuple, Union + + import torch ++import einops + + from megatron.core import parallel_state + from megatron.core.models.common.embeddings import ( +@@ -20,6 +21,7 @@ from megatron.core.transformer.enums import AttnMaskType + from megatron.core.transformer.module import MegatronModule + from megatron.core.transformer.spec_utils import ModuleSpec, build_module + from megatron.core.transformer.transformer_config import TransformerConfig ++from megatron.core.transformer.dot_product_attention_context_parallel import AllGatherComm, AttentionFuncionWithContextParallel + + try: + from fast_hadamard_transform import hadamard_transform +@@ -191,44 +193,72 @@ def compute_dsa_indexer_loss( + Returns: + index_loss: KL divergence loss (scalar). + """ +- sq, b, np, hn = query.size() +- sk = key.size(0) ++ cp_size = parallel_state.get_context_parallel_world_size() + +- # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] +- query = query.permute(1, 2, 0, 3).reshape(b * np, sq, hn) +- # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] +- key = key.permute(1, 2, 3, 0).reshape(b * np, hn, sk) +- # Compute attention scores [b * np, sq, sk] +- attention_scores = torch.bmm(query.float(), key.float()) * softmax_scale +- # Reshape to [b, np, sq, sk] +- attention_scores = attention_scores.reshape(b, np, sq, sk) ++ if cp_size > 1: ++ sq_local, b, np, hn = query.size() ++ sk_local = key.size(0) ++ sk_global = sk_local * cp_size + +- # causal_mask [sq, sk] +- causal_mask = torch.triu( +- torch.full((sq, sk), float('-inf'), dtype=torch.float32, device=attention_scores.device), +- diagonal=1, +- ) +- # index_mask [b, sq, sk] +- index_mask = torch.full( +- (b, sq, sk), float("-inf"), dtype=torch.float32, device=causal_mask.device +- ).scatter_(-1, topk_indices, 0) +- +- # [b, np, sq, skv] + [1, 1, sq, skv] -> [b, np, sq, skv] +- attention_scores += causal_mask.view(1, 1, sq, sk) +- if sparse_loss: +- # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] +- attention_scores += index_mask.view(b, 1, sq, sk) +- # [b, sq, sk] + [b, sq, sk] -> [b, sq, sk] +- index_scores += index_mask +- +- # [b, np, sq, sk] -> [b, np, sq, sk] +- attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32) +- # [b, sq, sk] -> [b, sq, sk] +- index_scores = torch.nn.functional.softmax(index_scores, dim=-1, dtype=torch.float32) ++ causal_mask = get_causal_mask(sq_local, sk_local, query.device) ++ float_mask = torch.zeros_like(causal_mask, dtype=torch.float32).masked_fill( ++ causal_mask, float('-inf') ++ ) ++ ++ index_mask = torch.full( ++ (b, sq_local, sk_global), float("-inf"), dtype=torch.float32, device=causal_mask.device ++ ).scatter_(-1, topk_indices, 0) + +- # Sum attention scores across heads. 
+- # [batch, heads, seqlen_q, seqlen_k] -> [batch, seqlen_q, seqlen_k] +- attention_scores = attention_scores.sum(dim=1) ++ float_mask = float_mask.view(1, 1, sq_local, sk_global) ++ float_mask = index_mask.view(b, 1, sq_local, sk_global) + float_mask if sparse_loss else float_mask ++ ++ # because the attention computation is more heavy in memory (has head dim), ++ # we apply cp (all-gather backend) on attention scores computation ++ attention_scores = compute_attention_scores_with_cp(query, key, float_mask, softmax_scale) # [b, sq_local, sk_global] ++ ++ index_scores = torch.nn.functional.softmax(index_scores, dim=-1, dtype=torch.float32) ++ ++ else: ++ sq, b, np, hn = query.size() ++ sk = key.size(0) ++ ++ # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] ++ query = query.permute(1, 2, 0, 3).reshape(b * np, sq, hn) ++ # [sk, b, np, hn] -> [b, np, hn, sk] -> [b * np, hn, sk] ++ key = key.permute(1, 2, 3, 0).reshape(b * np, hn, sk) ++ # Compute attention scores [b * np, sq, sk] ++ attention_scores = torch.bmm(query.float(), key.float()) * softmax_scale ++ # Reshape to [b, np, sq, sk] ++ attention_scores = attention_scores.reshape(b, np, sq, sk) ++ ++ # causal_mask [sq, sk] ++ causal_mask = torch.triu( ++ torch.full((sq, sk), float('-inf'), dtype=torch.float32, device=attention_scores.device), ++ diagonal=1, ++ ) ++ # index_mask [b, sq, sk] ++ index_mask = torch.full( ++ (b, sq, sk), float("-inf"), dtype=torch.float32, device=causal_mask.device ++ ).scatter_(-1, topk_indices, 0) ++ ++ # [b, np, sq, skv] + [1, 1, sq, skv] -> [b, np, sq, skv] ++ attention_scores += causal_mask.view(1, 1, sq, sk) ++ if sparse_loss: ++ # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] ++ attention_scores += index_mask.view(b, 1, sq, sk) ++ # [b, sq, sk] + [b, sq, sk] -> [b, sq, sk] ++ index_scores += index_mask ++ ++ # [b, np, sq, sk] -> [b, np, sq, sk] ++ attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32) ++ # [b, sq, sk] -> [b, sq, sk] ++ index_scores = torch.nn.functional.softmax(index_scores, dim=-1, dtype=torch.float32) ++ ++ # Sum attention scores across heads. ++ # [batch, heads, seqlen_q, seqlen_k] -> [batch, seqlen_q, seqlen_k] ++ attention_scores = attention_scores.sum(dim=1) ++ ++ # Common part + if pg_collection.tp.size() > 1: + # attention scores are scattered to TP ranks in head dimension. 
+ torch.distributed.all_reduce(attention_scores.contiguous(), group=pg_collection.tp) +@@ -251,6 +281,56 @@ def compute_dsa_indexer_loss( + + return indexer_loss + ++def compute_attention_scores_with_cp(q, k, attn_bias, scale, heads_k_stride = 1): ++ """ ++ compute attention scores of q_local @ k_global with CP all-gather backend ++ parallel on n_heads dimension ++ """ ++ pg = parallel_state.get_context_parallel_group() ++ cp_size = parallel_state.get_context_parallel_world_size() ++ ++ sq_local, b, nheads, hn_q = q.shape ++ sk_local, _, nheads_k, hn_k = k.shape ++ sk_global = sk_local * cp_size ++ ++ assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0 ++ ++ comm = AllGatherComm(group=pg) ++ attns = torch.zeros(b, heads_k_stride, sq_local, sk_global, dtype=q.dtype, device=q.device) ++ ++ k_buffer = torch.empty( ++ (sk_global, b, heads_k_stride, hn_k), ++ dtype=k.dtype, ++ device=k.device ++ ) ++ k_buffer_copy = torch.empty_like(k_buffer) ++ k_0 = k[:, :, :heads_k_stride].contiguous() ++ comm.all_gather(k_buffer_copy, k_0) ++ ++ attn_bias = attn_bias.expand(-1, heads_k_stride * (nheads // nheads_k), -1, -1) ++ ++ for i in range(0, nheads_k, heads_k_stride): ++ comm.wait() ++ k_buffer, k_buffer_copy = k_buffer_copy, k_buffer ++ if i < nheads_k - heads_k_stride: ++ kvsl = i + heads_k_stride ++ kvsr = kvsl + heads_k_stride ++ send_k = k[:, :, kvsl:kvsr].contiguous() ++ comm.all_gather(k_buffer_copy, send_k) ++ q_i = q[:, :, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] ++ k_i = k_buffer ++ ++ _q_i = einops.rearrange(q_i, 's b h d -> b h s d') ++ _k_i = einops.rearrange(k_i, 's b h d -> b h d s') ++ attn_i = torch.matmul(_q_i.float(), _k_i.float()) * scale + attn_bias ++ attn_i = torch.nn.functional.softmax(attn_i, dim=-1, dtype=torch.float32) ++ ++ attns = attns + attn_i ++ ++ attns = torch.sum(attns, dim=1) ++ ++ return attns ++ + + class DSAIndexerLossAutoScaler(torch.autograd.Function): + """An AutoScaler that triggers the backward pass and scales the grad for indexer loss. +@@ -496,7 +576,15 @@ class DSAIndexer(MegatronModule): + # Compute attention scores: q @ k^T + # [seqlen_q, batch, index_n_heads, index_head_dim] @ [seqlen_k, batch, index_head_dim]^T + # -> [seqlen_q, batch, index_n_heads, seqlen_k] +- index_scores = torch.einsum('sbhd,tbd->sbht', q.float(), k.float()) ++ cp_size = parallel_state.get_context_parallel_world_size() ++ if cp_size == 1: ++ index_scores = torch.einsum('sbhd,tbd->sbht', q.float(), k.float()) ++ else: ++ # because k is small (only 1 head), do just one all_gather ++ k_buffer = torch.cat(torch.distributed.nn.functional.all_gather(k, group=self.pg_collection.cp), dim=0) # k_buffer: [[chunk_0, chunk_3, chunk_1, chunk_2], batch, index_head_dim] ++ index_scores = torch.einsum('sbhd,tbd->sbht', q.float(), k_buffer.float()) # [s_q_local, batch, index_n_heads, s_k_global] ++ # rank 0: q [chunk_0, chunk_3], k[chunk_0, chunk_3, chunk_1, chunk_2] ++ # rank 1: q [chunk_1, chunk_2], k[chunk_0, chunk_3, chunk_1, chunk_2] + + # Apply ReLU activation. 
+ index_scores = torch.relu(index_scores) +@@ -546,14 +634,10 @@ class DSAIndexer(MegatronModule): + None, None, x, self.config, packed_seq_params + ) + if self.config.rope_type == "rope": +- rotary_pos_emb = self.rotary_pos_emb( +- rotary_seq_len, packed_seq_params=packed_seq_params +- ) ++ rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + mscale = 1.0 + else: +- rotary_pos_emb, mscale = self.rotary_pos_emb( +- rotary_seq_len, packed_seq_params=packed_seq_params +- ) ++ rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + + # ========================================= + # Gather inputs if sp is enabled +@@ -610,7 +694,9 @@ class DSAIndexer(MegatronModule): + # ========================================= + # Select top-k indices + # ========================================= +- topk_k = min(self.index_topk, seqlen) ++ cp_size = parallel_state.get_context_parallel_world_size() ++ seqlen_k_global = k.shape[0] * cp_size ++ topk_k = min(self.index_topk, seqlen_k_global) + # [batch, seqlen, index_topk] + topk_indices = index_scores.topk(topk_k, dim=-1)[1] + +@@ -691,6 +777,48 @@ def unfused_dsa_fn(query, key, value, topk_indices, softmax_scale): + output = output.reshape(sq, b, np * hnv) + return output + ++def get_causal_mask(sq, skv, device): ++ cp_size = parallel_state.get_context_parallel_world_size() ++ cp_rank = parallel_state.get_context_parallel_rank() ++ skv_global = skv * cp_size ++ ++ if cp_size == 1: ++ causal_mask = torch.triu( ++ torch.ones((sq, skv), dtype=torch.bool, device=device), ++ diagonal=1, ++ ) ++ else: ++ sq_half = sq // 2 ++ global_q_positions = torch.cat([ ++ torch.arange(cp_rank * sq_half, (cp_rank + 1) * sq_half, device=device), ++ torch.arange(skv_global - (cp_rank + 1) * sq_half, skv_global - cp_rank * sq_half, device=device) ++ ]) ++ ++ global_k_positions = torch.arange(skv_global, device=device) ++ # [sq, 1] < [1, skv_global] -> [sq, skv_global] ++ causal_mask = global_q_positions.unsqueeze(1) < global_k_positions.unsqueeze(0) ++ # convert to zz mask ++ chunked = causal_mask.chunk(dim=1, chunks=cp_size * 2) ++ causal_mask = [_x for _p in zip(chunked[:cp_size], reversed(chunked[cp_size:])) for _x in _p] ++ causal_mask = torch.cat(causal_mask, dim=1) ++ ++ return causal_mask ++ ++def unfused_dsa_fn_with_cp(query, key, dim_v, topk_indices, softmax_scale): ++ pg = parallel_state.get_context_parallel_group() ++ sq, b, np, hn = query.size() ++ skv = key.size(0) ++ ++ topk = topk_indices.shape[-1] ++ topk_indices = topk_indices.unsqueeze(1) ++ topk_indices = topk_indices.expand(-1, key.shape[2], -1, -1).contiguous().to(torch.int32) ++ causal_masks = get_causal_mask(sq, skv, query.device) ++ causal_masks = causal_masks[None, None, :, :] ++ causal_masks = causal_masks.expand(b, key.shape[2], -1, -1).contiguous() ++ output = AttentionFuncionWithContextParallel.apply( ++ query, key, dim_v, topk_indices, causal_masks, 0.0, softmax_scale, pg ++ ) ++ return output.reshape(sq, b, np * dim_v) + + class DSAttention(MegatronModule): + """ +@@ -733,7 +861,6 @@ class DSAttention(MegatronModule): + self, + query: torch.Tensor, + key: torch.Tensor, +- value: torch.Tensor, + x: torch.Tensor, + qr: torch.Tensor, + attention_mask: torch.Tensor, +@@ -747,7 +874,6 @@ class DSAttention(MegatronModule): + Args: + query: Query tensor [sq, b, np, hn]. + key: Key tensor [skv, b, np, hn]. +- value: Value tensor [skv, b, np, hnv]. + x: Original hidden states [sq, b, hidden_size]. 
+ qr: Low-rank query representation [sq, b, q_lora_rank]. + attention_mask: Attention mask tensor [b, 1, sq, sk]. +@@ -758,9 +884,11 @@ class DSAttention(MegatronModule): + Returns: + output: Output tensor [sq, b, hidden_size] + """ +- sq, b, np, hn = query.size() +- skv = key.size(0) +- hnv = value.size(3) ++ dim_v = self.config.kv_lora_rank ++ # torch.Size([128, 1, 64, 576]) ++ sq, b, nheads, dim = query.size() ++ # torch.Size([128, 1, 1, 576]) ++ skv, _, kv_groups, _ = key.shape + + # Detach x and qr to prevent gradients of indexer from flowing back to the main model. + x = x.detach() +@@ -772,18 +900,17 @@ class DSAttention(MegatronModule): + # Generate upper triangular mask with -inf above diagonal, 0 elsewhere + # torch.triu with diagonal=1 creates upper triangular matrix (excluding main diagonal) + # float_mask [sq, skv] +- float_mask = torch.triu( +- torch.full((sq, skv), float('-inf'), dtype=torch.float32, device=x.device), +- diagonal=1, +- ) ++ mask = get_causal_mask(sq, skv, x.device) + else: +- assert attention_mask.shape == (b, 1, sq, skv), 'attention_mask shape mismatch' ++ skv_global = skv * parallel_state.get_context_parallel_world_size() ++ assert attention_mask.shape == (b, 1, sq, skv_global), 'attention_mask shape mismatch' + # [b, 1, sq, skv] -> [b, sq, skv] + mask = attention_mask.squeeze() +- # float_mask [b, sq, skv] +- float_mask = torch.zeros_like(mask, dtype=torch.float32).masked_fill( +- mask, float('-inf') +- ) ++ ++ # float_mask [b, sq, skv] ++ float_mask = torch.zeros_like(mask, dtype=torch.float32).masked_fill( ++ mask, float('-inf') ++ ) + + # =================================== + # Get index scores and top-k indices +@@ -795,32 +922,6 @@ class DSAttention(MegatronModule): + # =================================== + # Run sparse attention kernel + # =================================== +- output = unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) +- +- # =================================== +- # Attach indexer loss +- # =================================== +- if self.training and torch.is_grad_enabled(): +- # Compute KL divergence loss between indexer scores and true attention scores +- indexer_loss_coeff = getattr(self.config, 'dsa_indexer_loss_coeff', 0.0) +- indexer_loss = compute_dsa_indexer_loss( +- index_scores, +- topk_indices, +- query.detach(), +- key.detach(), +- self.softmax_scale, +- indexer_loss_coeff, +- getattr(self.config, "dsa_indexer_use_sparse_loss", False), +- self.indexer.pg_collection, +- ) +- # Save indexer loss for logging +- if indexer_loss_coeff > 0: +- DSAIndexerLossLoggingHelper.save_loss_to_tracker( +- loss=indexer_loss, +- layer_number=self.layer_number, +- num_layers=self.config.num_layers, +- ) +- # Attach loss to output +- output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) +- ++ output = unfused_dsa_fn_with_cp(query, key, dim_v, topk_indices, self.softmax_scale) ++ + return output +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..befb5c124 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -586,6 +586,9 @@ def topk_routing_with_score_function( + ) + else: + return torch.topk(scores, k=topk, dim=1) ++ ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) + + if score_function == "softmax": + if use_pre_softmax: +diff --git a/megatron/core/transformer/moe/router.py 
b/megatron/core/transformer/moe/router.py +index 16fc9d9af..3c50a9516 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -200,6 +200,9 @@ class TopKRouter(Router): + else: + self.global_tokens_per_expert = None + self.ga_steps = None ++ ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) + + def _maintain_float32_expert_bias(self): + """ +diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py +index ed90fdffa..7a7597d66 100644 +--- a/megatron/core/transformer/multi_latent_attention.py ++++ b/megatron/core/transformer/multi_latent_attention.py +@@ -15,7 +15,7 @@ except ImportError: + HAVE_EINOPS = False + + +-from megatron.core import tensor_parallel ++from megatron.core import parallel_state, tensor_parallel + from megatron.core.models.common.embeddings import ( + RotaryEmbedding, + YarnRotaryEmbedding, +@@ -312,7 +312,6 @@ class MultiLatentAttention(Attention): + core_attn_out = self.core_attention( + query, + key, +- value, + x=hidden_states, + qr=q_compressed, + attention_mask=attention_mask, +@@ -371,6 +370,19 @@ class MultiLatentAttention(Attention): + self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out) + self.qkv_up_checkpoint = None + ++ s_, b_ = core_attn_out.size(0), core_attn_out.size(1) ++ core_attn_out = core_attn_out.view( ++ s_, b_, ++ self.num_attention_heads_per_partition, ++ self.config.kv_lora_rank ++ ) ++ ++ # einsum: "sbhk,hdk->sbhd" ++ core_attn_out = torch.einsum("sbhk,hdk->sbhd", core_attn_out, self.up_v_weight_) ++ core_attn_out = core_attn_out.contiguous() ++ core_attn_out = core_attn_out.view(s_, b_, -1) ++ core_attn_out = core_attn_out.contiguous() ++ + # ================= + # Output. [sq, b, h] + # ================= +@@ -555,11 +567,7 @@ class MLASelfAttention(MultiLatentAttention): + assert ( + hidden_states.ndim == 3 + ), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" +- if packed_seq_params is not None: +- assert ( +- packed_seq_params.local_cp_size is None +- ), "hybrid_context_parallel is not supported with MLA yet and is planned for future. \ +- Please disable hybrid_context_parallel." 
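The MultiLatentAttention changes in this hunk (together with the `mla_absorb` helper added further down) move training to the "absorbed" MLA formulation: the core attention runs in the compressed `kv_lora_rank` space, `up_k_weight_` is folded into the query, and the per-head value up-projection `up_v_weight_` is applied to the attention output via `einsum("sbhk,hdk->sbhd", ...)`. The sketch below is a small standalone check of that identity; it omits the RoPE tail dimensions, batching, and the softmax scale, and is illustrative rather than part of the patch.

```python
# Numerical check of the MLA "weight absorption" identity: attending over the
# compressed latent cache with W_UK folded into the query, then up-projecting the
# output with W_UV, matches attending over materialized per-head K/V.
import torch

torch.manual_seed(0)
s, h, d, dv, r = 4, 2, 8, 8, 16        # tokens, heads, qk_head_dim, v_head_dim, kv_lora_rank

q = torch.randn(s, h, d)               # q_no_pe
c = torch.randn(s, r)                  # compressed KV latent, one vector per token
w_uk = torch.randn(h, d, r)            # up_k_weight_: [heads, qk_head_dim, kv_lora_rank]
w_uv = torch.randn(h, dv, r)           # up_v_weight_: [heads, v_head_dim, kv_lora_rank]

# Reference path: materialize per-head K and V from the latent, then attend.
k = torch.einsum('tr,hdr->thd', c, w_uk)                 # [s, h, d]
v = torch.einsum('tr,hvr->thv', c, w_uv)                 # [s, h, dv]
probs = torch.einsum('shd,thd->hst', q, k).softmax(dim=-1)
out_ref = torch.einsum('hst,thv->shv', probs, v)

# Absorbed path: fold W_UK into q, attend over the latent itself, then apply W_UV
# to the latent-space attention output (the einsum("sbhk,hdk->sbhd", ...) step).
q_abs = torch.einsum('shd,hdr->shr', q, w_uk)            # q_absorbed
probs_abs = torch.einsum('shr,tr->hst', q_abs, c).softmax(dim=-1)
latent_out = torch.einsum('hst,tr->shr', probs_abs, c)   # core_attn_out analogue
out_abs = torch.einsum('shr,hvr->shv', latent_out, w_uv)

assert torch.allclose(out_ref, out_abs, atol=1e-4)
```

This identity is also why the patch can drop the separate `value` argument from `core_attention`: the value heads are recovered after the kernel from the same latent-space output.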
++ + + inference_context = deprecate_inference_params(inference_context, inference_params) + +@@ -576,9 +584,7 @@ class MLASelfAttention(MultiLatentAttention): + rotary_pos_sin = None + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if self.config.rope_type == "rope": +- rotary_pos_emb = self.rotary_pos_emb( +- rotary_seq_len, packed_seq_params=packed_seq_params +- ) ++ rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + else: + if self.config.apply_rope_fusion: + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( +@@ -591,11 +597,9 @@ class MLASelfAttention(MultiLatentAttention): + and fused_apply_mla_rope_for_kv is not None + ), "Fused MLA RoPE apply is not imported successfully" + else: +- rotary_pos_emb, mscale = self.rotary_pos_emb( +- rotary_seq_len, packed_seq_params=packed_seq_params +- ) ++ rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + +- if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': ++ if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: +@@ -867,6 +871,98 @@ class MLASelfAttention(MultiLatentAttention): + + return query, key, value + ++ def mla_absorb(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): ++ if self.config.q_lora_rank is not None: ++ # q_compressed: [num_tokens, q_lora_rank] ++ # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] ++ q, _ = self.linear_q_up_proj(q_compressed) ++ else: ++ # q_compressed: [num_tokens, hidden_size] ++ # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] ++ q, _ = self.linear_q_proj(q_compressed) ++ ++ # q: [num_tokens, n, q_head_dim] ++ q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim) ++ ++ # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim] ++ k_pos_emb = torch.unsqueeze(k_pos_emb, -2) ++ ++ if self.config.apply_rope_fusion: ++ raise NotImplementedError( ++ "RoPE fusion is not yet supported with absorption training. " ++ "Please set apply_rope_fusion=False." ++ ) ++ else: ++ q_len = q.size()[0] ++ if inference_context is not None: ++ # add offset to the sequence start for inference ++ sequence_start = inference_context.sequence_len_offset ++ sequence_end = sequence_start + q_len ++ rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] ++ elif packed_seq_params is None or self.config.context_parallel_size == 1: ++ rotary_pos_emb = rotary_pos_emb[0:q_len] ++ ++ # q_no_pe: [num_tokens, n, qk_head_dim] ++ # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] ++ q_no_pe, q_pos_emb = torch.split( ++ q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1 ++ ) ++ ++ # q_no_pe: [num_tokens, n, qk_head_dim] ++ # up_k_weight: [n, qk_head_dim, kv_lora_rank] ++ # q_absorbed: [num_tokens, n, kv_lora_rank] ++ q_absorbed = torch.einsum("...hd,hdk->...hk", q_no_pe, self.up_k_weight_) ++ ++ # TODO: Does it match ZZ? 
SP does not need but CP needs ++ if self.config.sequence_parallel: ++ kv_compressed = gather_from_sequence_parallel_region(kv_compressed, group=self.tp_group) ++ ++ # kv_compressed: [num_tokens, kv_lora_rank] ++ if kv_compressed.ndim == 3: # [s, b, kv_lora_rank] ++ k_content = kv_compressed.unsqueeze(2).expand( ++ -1, -1, 1, -1 ++ ) ++ else: # [t, kv_lora_rank] for packed sequence ++ k_content = kv_compressed.unsqueeze(1).expand( ++ -1, 1, -1 ++ ) ++ ++ # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] ++ q_pos_emb = apply_rotary_pos_emb( ++ q_pos_emb, ++ rotary_pos_emb, ++ config=self.config, ++ cu_seqlens=cu_seqlens_q, ++ mscale=mscale, ++ cp_group=self.pg_collection.cp, ++ ) ++ # k_pos_emb: [num_tokens, 1, qk_pos_emb_head_dim] ++ k_pos_emb = apply_rotary_pos_emb( ++ k_pos_emb, ++ rotary_pos_emb, ++ config=self.config, ++ cu_seqlens=cu_seqlens_kv, ++ mscale=mscale, ++ cp_group=self.pg_collection.cp, ++ ) ++ ++ # query: [num_tokens, n, kv_lora_rank + qk_pos_emb_head_dim] ++ query = torch.cat([q_absorbed, q_pos_emb], dim=-1) ++ ++ # key: [num_tokens, n, kv_lora_rank + qk_pos_emb_head_dim] ++ if k_pos_emb.ndim == 4: ++ k_pos_emb = k_pos_emb.expand(-1, -1, 1, -1) ++ else: ++ assert k_pos_emb.ndim == 3 ++ k_pos_emb = k_pos_emb.expand(-1, 1, -1) ++ ++ key = torch.cat([k_content, k_pos_emb], dim=-1) ++ ++ query = query.contiguous() ++ key = key.contiguous() ++ ++ return query, key ++ + if self.recompute_up_proj: + quantization = self.config.fp8 or self.config.fp4 + self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=quantization) +@@ -882,9 +978,10 @@ class MLASelfAttention(MultiLatentAttention): + q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) + else: +- query, key, value = qkv_up_proj_and_rope_apply( ++ query, key = mla_absorb( + q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) ++ value = None + + if return_compressed_tensors: + return query, key, value, q_compressed, kv_compressed +@@ -1128,3 +1225,26 @@ class MLASelfAttention(MultiLatentAttention): + ) + + return weight_kv_updated ++ ++ @property ++ def up_k_weight_(self): ++ # linear_kv_up_proj.weight: [num_heads_per_partition * (qk_head_dim + v_head_dim), kv_lora_rank] ++ weight = self.linear_kv_up_proj.weight ++ weight_reshaped = weight.view( ++ self.num_attention_heads_per_partition, ++ self.config.qk_head_dim + self.config.v_head_dim, ++ self.config.kv_lora_rank, ++ ) ++ # [num_heads_per_partition, qk_head_dim, kv_lora_rank] ++ return weight_reshaped[:, :self.config.qk_head_dim, :] ++ ++ @property ++ def up_v_weight_(self): ++ weight = self.linear_kv_up_proj.weight ++ weight_reshaped = weight.view( ++ self.num_attention_heads_per_partition, ++ self.config.qk_head_dim + self.config.v_head_dim, ++ self.config.kv_lora_rank, ++ ) ++ # [num_heads_per_partition, v_head_dim, kv_lora_rank] ++ return weight_reshaped[:, self.config.qk_head_dim:, :] +\ No newline at end of file +diff --git a/megatron/core/transformer/tilelang_kernel/__init__.py b/megatron/core/transformer/tilelang_kernel/__init__.py +new file mode 100644 +index 000000000..c63794256 +--- /dev/null ++++ b/megatron/core/transformer/tilelang_kernel/__init__.py +@@ -0,0 +1,10 @@ ++# Code is adopted from tilelang/examples/deepseek_v32 ++# transformer/tilelang_kernel/__init__.py ++ ++from .sparse_mla_fwd import sparse_mla_fwd_interface ++from .sparse_mla_bwd import sparse_mla_bwd ++ ++__all__ = [ ++ "sparse_mla_fwd_interface", ++ "sparse_mla_bwd", ++] +\ No newline at end of file +diff --git 
a/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py +new file mode 100644 +index 000000000..83a259efa +--- /dev/null ++++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py +@@ -0,0 +1,272 @@ ++# ruff: noqa ++import tilelang ++from tilelang import language as T ++import torch ++ ++ ++@tilelang.jit(out_idx=[-1]) ++def preprocess( ++ B, ++ S, ++ H, ++ D, ++ block_ND=32, ++ num_stages=5, ++ dtype=T.bfloat16, ++ accum_dtype=T.float32, ++): ++ assert dtype == T.bfloat16 ++ assert accum_dtype == T.float32 ++ shape = [B, S, H, D] ++ ++ @T.prim_func ++ def preprocess_kernel( ++ O: T.Tensor(shape, dtype), ++ dO: T.Tensor(shape, dtype), ++ Delta: T.Tensor([B, S, H], accum_dtype), ++ ): ++ with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): ++ o = T.alloc_fragment([block_ND, block_ND], accum_dtype) ++ do = T.alloc_fragment([block_ND, block_ND], accum_dtype) ++ delta = T.alloc_fragment([block_ND], accum_dtype) ++ acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) ++ T.clear(acc) ++ for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): ++ T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) ++ T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) ++ for i, j in T.Parallel(block_ND, block_ND): ++ acc[i, j] += o[i, j] * do[i, j] ++ T.reduce_sum(acc, delta, 1) ++ T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) ++ ++ return preprocess_kernel ++ ++ ++@tilelang.jit(out_idx=[-1]) ++def postprocess( ++ B, ++ S_kv, ++ D, ++ D_tail, ++ kv_group=1, ++ block_N=64, ++ threads=256, ++ dtype=T.bfloat16, ++ accum_dtype=T.float32, ++): ++ assert dtype == T.bfloat16 ++ assert accum_dtype == T.float32 ++ dkv_shape = [B, S_kv, kv_group, D + D_tail] ++ ++ @T.prim_func ++ def postprocess_kernel( ++ dKV: T.Tensor(dkv_shape, accum_dtype), ++ dKV_out: T.Tensor(dkv_shape, dtype), ++ ): ++ with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): ++ T.copy( ++ dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], ++ dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], ++ ) ++ ++ return postprocess_kernel ++ ++ ++@tilelang.jit( ++ out_idx=[-2], ++ pass_configs={ ++ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, ++ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, ++ tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, ++ }, ++) ++def bwd( ++ B, ++ S, ++ S_kv, ++ H, ++ D, ++ D_tail, ++ topk, ++ kv_group=1, ++ sm_scale=None, ++ is_causal=True, ++ block_size=32, ++ num_stages=0, ++ threads=128, ++ indices_dtype=T.int32, ++ dtype=T.bfloat16, ++ accum_dtype=T.float32, ++ masks_dtype=T.bool, ++): ++ assert is_causal == True, "non-casual is not supported now" ++ assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" ++ assert dtype == T.bfloat16 ++ assert accum_dtype == T.float32 ++ assert indices_dtype == T.int32 ++ ++ if sm_scale is None: ++ sm_scale = (D + D_tail) ** (-0.5) ++ sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) ++ ++ H_kv = H // kv_group ++ q_shape = [B, S, H, D + D_tail] ++ k_shape = [B, S_kv, kv_group, D + D_tail] ++ o_shape = [B, S, H, D] ++ indices_shape = [B, S, kv_group, topk] ++ delta_shape = [B, S, H] ++ lse_shape = [B, S, H] ++ masks_shape = [B, S, kv_group, S_kv] ++ assert indices_dtype == T.int32 ++ assert dtype == T.bfloat16 ++ assert accum_dtype == T.float32 ++ ++ H = H_kv ++ 
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) ++ block_H = min(64, padded_H) ++ assert padded_H % block_H == 0 ++ NH = padded_H // block_H ++ BS = block_size ++ NS = tilelang.cdiv(topk, block_size) ++ ++ split_store = 2 ++ ++ @T.prim_func ++ def sparse_mla_bwd_kernel( ++ Q: T.Tensor(q_shape, dtype), ++ KV: T.Tensor(k_shape, dtype), ++ dO: T.Tensor(o_shape, dtype), ++ Indices: T.Tensor(indices_shape, indices_dtype), ++ Masks: T.Tensor(masks_shape, masks_dtype), ++ Lse: T.Tensor(lse_shape, accum_dtype), ++ Delta: T.Tensor(delta_shape, accum_dtype), ++ dQ: T.Tensor(q_shape, dtype), ++ dKV: T.Tensor(k_shape, accum_dtype), ++ ): ++ with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): ++ Q_shared = T.alloc_shared([block_H, D], dtype) ++ Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) ++ KV_shared = T.alloc_shared([BS, D], dtype) ++ KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) ++ dO_shared = T.alloc_shared([block_H, D], dtype) ++ mask = T.alloc_fragment([BS], "bool") ++ ++ P_shared_cast = T.alloc_shared([block_H, BS], dtype) ++ dP_shared_cast = T.alloc_shared([block_H, BS], dtype) ++ dQ_shared = T.alloc_shared([block_H, D], dtype) ++ dQ_tail_shared = T.alloc_shared([block_H, D_tail], dtype) ++ ++ acc_p = T.alloc_fragment([block_H, BS], accum_dtype) ++ acc_dp = T.alloc_fragment([block_H, BS], accum_dtype) ++ acc_dq = T.alloc_fragment([block_H, D], accum_dtype) ++ acc_dq_tail = T.alloc_fragment([block_H, D_tail], accum_dtype) ++ acc_dkv = T.alloc_fragment([BS, D], accum_dtype) ++ acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) ++ acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) ++ acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) ++ ++ T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, :D], Q_shared) ++ T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, D:], Q_tail_shared) ++ T.copy(dO[by, s_i, bz * block_H : (bz + 1) * block_H, :D], dO_shared) ++ ++ T.clear(acc_dq) ++ T.clear(acc_dq_tail) ++ ++ # Process each block of indices ++ for i_i in T.Pipelined(NS, num_stages=num_stages): ++ # Compute attention scores ++ for bi_i in T.Parallel(BS): ++ mask[bi_i] = Masks[by, s_i, bz // NH, Indices[by, s_i, bz // NH, i_i * BS + bi_i]] ++ ++ for h_i, bi_i in T.Parallel(block_H, BS): ++ acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], -T.infinity(acc_p.dtype), 0) ++ ++ # Load KV, V for this block of indices ++ for bi_i, d_i in T.Parallel(BS, D): ++ KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, d_i] ++ ++ T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) ++ ++ for bi_i, d_i in T.Parallel(BS, D_tail): ++ KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, D + d_i] ++ T.gemm(Q_tail_shared, KV_tail_shared[:, :D_tail], acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) ++ ++ for h_i, bi_i in T.Parallel(block_H, BS): ++ acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * block_H + h_i]) ++ ++ T.copy(acc_p, P_shared_cast) ++ ++ T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) ++ ++ for h_i, bi_i in T.Parallel(block_H, BS): ++ acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * block_H + h_i]) * sm_scale ++ ++ T.copy(acc_dp, dP_shared_cast) ++ T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) ++ T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, 
policy=T.GemmWarpPolicy.FullCol) ++ ++ T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) ++ T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) ++ ++ T.clear(acc_dkv_tail) ++ T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) ++ ++ for s in range(split_store): ++ for bi_i, d_i in T.Parallel(BS, D): ++ if bi_i < BS // split_store: ++ acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] ++ ++ for bi_i, d_i in T.Parallel(BS, D_tail): ++ if bi_i < BS // split_store: ++ acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] ++ ++ for bi_i, d_i in T.Parallel(BS // split_store, D // 4): ++ T.atomic_addx4( ++ dKV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, d_i * 4], ++ acc_dkv_shared[bi_i, d_i * 4], ++ ) ++ ++ # Atomically update dKV, dKV_tail tensors ++ for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): ++ T.atomic_addx4( ++ dKV[by, Indices[by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store)], bz // NH, D + d_i * 4], ++ acc_dkv_tail_shared[bi_i, d_i * 4], ++ ) ++ ++ # Store the accumulated dQ ++ T.copy(acc_dq, dQ_shared) ++ T.copy(acc_dq_tail[:, :D_tail], dQ_tail_shared) ++ ++ T.copy(dQ_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, :D]) ++ T.copy(dQ_tail_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, D:]) ++ ++ return sparse_mla_bwd_kernel ++ ++ ++def sparse_mla_bwd(q, kv, o, do, indices, masks, lse, dim_v, sm_scale=None, is_casual=True, return_kernel=False, delta=None): ++ assert q.is_contiguous() ++ assert kv.is_contiguous() ++ assert indices.is_contiguous() ++ assert lse.is_contiguous() ++ B, S, H, dim_plus_tail_dim = q.shape ++ _, S_kv, kv_group, _ = kv.shape ++ assert kv.shape[-1] == dim_plus_tail_dim ++ assert kv.shape[0] == B ++ # dim should be assigned ++ D = dim_v ++ ++ D_tail = dim_plus_tail_dim - D ++ topk = indices.shape[-1] ++ assert indices.shape == (B, S, kv_group, topk) ++ assert lse.shape == (B, S, H) ++ ++ # Get kernels ++ preprocess_kernel = preprocess(B, S, H, D) ++ bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) ++ ++ if delta is None: ++ delta = preprocess_kernel(o, do) ++ dkv = torch.zeros_like(kv, dtype=torch.float32) ++ dq = bwd_kernel(q, kv, do, indices, masks, lse, delta, dkv) ++ ++ return dq, dkv +\ No newline at end of file +diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py +new file mode 100644 +index 000000000..e247038de +--- /dev/null ++++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py +@@ -0,0 +1,191 @@ ++# ruff: noqa ++import torch ++import tilelang ++from tilelang import language as T ++ ++ ++@tilelang.jit( ++ out_idx=[-2, -1], ++ pass_configs={ ++ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, ++ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, ++ }, ++) ++def sparse_mla_fwd( ++ heads, ++ dim, ++ tail_dim, ++ topk, ++ kv_group=1, ++ sm_scale=None, ++ is_causal=True, ++ CP0=True, ++ block_I=64, ++ num_stages=2, ++ threads=256, ++): ++ assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" ++ assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" ++ assert is_causal == True, "non-casual is not supported" ++ assert topk % block_I == 0, 
"otherwise will load some index=0 thus causing wrong kv to be loaded" ++ if sm_scale is None: ++ sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) ++ else: ++ sm_scale = sm_scale * 1.44269504 # log2(e) ++ ++ batch = T.dynamic("batch") ++ seq_len = T.dynamic("seq_len") ++ seq_len_kv = T.dynamic("seq_len_kv") ++ ++ head_kv = heads // kv_group ++ q_shape = [batch, seq_len, heads, dim + tail_dim] ++ kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] ++ o_shape = [batch, seq_len, heads, dim] ++ indices_shape = [batch, seq_len, kv_group, topk] ++ lse_shape = [batch, seq_len, heads] ++ masks_shape = [batch, seq_len, kv_group, seq_len_kv] ++ ++ masks_dtype = T.bool ++ indices_dtype = T.int32 ++ dtype = T.bfloat16 ++ accum_dtype = T.float32 ++ ++ G = kv_group ++ H = head_kv ++ padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) ++ if padded_H != H: ++ assert kv_group == 1, ( ++ "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" ++ ) ++ BI = block_I ++ NI = tilelang.cdiv(topk, block_I) ++ D = dim ++ D_tail = tail_dim ++ ++ if head_kv > 64: ++ assert head_kv % 64 == 0, "head_kv should be a multiple of 64" ++ REPLICATE_H = head_kv // 64 ++ else: ++ REPLICATE_H = 1 ++ ++ H_per_block = padded_H if REPLICATE_H == 1 else 64 ++ ++ @T.prim_func ++ def main( ++ Q: T.Tensor(q_shape, dtype), # type: ignore ++ KV: T.Tensor(kv_shape, dtype), # type: ignore ++ Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore ++ Masks: T.Tensor(masks_shape, masks_dtype), # type: ignore ++ Output: T.Tensor(o_shape, dtype), # type: ignore ++ Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ++ ): ++ with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( ++ bx, ++ by, ++ bz, ++ ): ++ Q_shared = T.alloc_shared([H_per_block, D], dtype) ++ Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) ++ KV_shared = T.alloc_shared([BI, D], dtype) ++ K_tail_shared = T.alloc_shared([BI, D_tail], dtype) ++ O_shared = T.alloc_shared([H_per_block, D], dtype) ++ Lse_shared = T.alloc_shared([H_per_block], accum_dtype) ++ mask = T.alloc_fragment([BI], "bool") ++ ++ acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) ++ acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) ++ S_shared = T.alloc_shared([H_per_block, BI], dtype) ++ sumexp = T.alloc_fragment([H_per_block], accum_dtype) ++ sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) ++ alpha = T.alloc_fragment([H_per_block], accum_dtype) ++ m_i = T.alloc_fragment([H_per_block], accum_dtype) ++ m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) ++ ++ T.fill(acc_o, 0) ++ T.fill(sumexp, 0) ++ T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan ++ ++ b_i, g_i = by, bz ++ s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) ++ q_i = s_i ++ ++ H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) ++ H1 = H0 + H_per_block ++ ++ T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) ++ T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) ++ ++ for i_i in T.Pipelined(NI, num_stages=num_stages): ++ for bi_i in T.Parallel(BI): ++ mask[bi_i] = Masks[b_i, s_i, g_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i]] ++ ++ for bi_i, d_i in T.Parallel(BI, D): ++ KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] ++ for bi_i, d_i in T.Parallel(BI, D_tail): ++ K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] ++ for 
h_i, bi_i in T.Parallel(H_per_block, BI): ++ acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], -T.infinity(acc_s.dtype), 0) ++ T.gemm( ++ Q_shared, ++ KV_shared, ++ acc_s, ++ transpose_B=True, ++ policy=T.GemmWarpPolicy.FullRow, ++ ) ++ T.gemm( ++ Q_tail_shared, ++ K_tail_shared, ++ acc_s, ++ transpose_B=True, ++ policy=T.GemmWarpPolicy.FullRow, ++ ) ++ T.copy(m_i, m_i_prev) ++ T.reduce_max(acc_s, m_i, dim=1, clear=False) ++ for h_i in T.Parallel(H_per_block): ++ m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) ++ for h_i in T.Parallel(H_per_block): ++ alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) ++ for h_i, bi_i in T.Parallel(H_per_block, BI): ++ acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) ++ T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? ++ for h_i in T.Parallel(H_per_block): ++ sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] ++ for h_i, d_i in T.Parallel(H_per_block, D): ++ acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] ++ ++ T.copy(acc_s, S_shared) ++ T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) ++ ++ # Rescale ++ for h_i, d_i in T.Parallel(H_per_block, D): ++ acc_o[h_i, d_i] /= sumexp[h_i] ++ for h_i in T.Parallel(H_per_block): ++ sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale ++ ++ T.copy(acc_o, O_shared) ++ T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) ++ T.copy(sumexp, Lse_shared) ++ T.copy(sumexp, Lse[b_i, s_i, H0:H1]) ++ ++ return main ++ ++ ++def sparse_mla_fwd_interface(q, kv, indices, masks, d_v, sm_scale=None, return_p_sum: bool = False, block_I=64, num_stages=2, threads=256): ++ is_casual = True ++ assert return_p_sum == False, "This kernel file is for fwd only" ++ assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() ++ batch, seq_len, heads, dim_plus_tail_dim = q.shape ++ _, seq_len_kv, kv_group, _ = kv.shape ++ ++ assert kv.shape[-1] == dim_plus_tail_dim ++ tail_dim = dim_plus_tail_dim - d_v ++ assert kv.shape[0] == batch ++ _, _, _, topk = indices.shape ++ assert indices.shape == (batch, seq_len, kv_group, topk) ++ assert masks.shape == (batch, seq_len, kv_group, seq_len_kv) ++ ++ kernel = sparse_mla_fwd( ++ heads, d_v, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads ++ ) ++ out, lse = kernel(q, kv, indices, masks) ++ return out, lse +\ No newline at end of file +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..29a0ff9e0 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -935,11 +935,10 @@ class TransformerConfig(ModelParallelConfig): + f" but got {self.context_parallel_size=}." + ) + elif self.experimental_attention_variant == "dsa": +- assert ( +- self.context_parallel_size == 1 +- ), "Currently context parallelism is not supported by DSAttention!" ++ # assert ( ++ # self.context_parallel_size == 1 ++ # ), "Currently context parallelism is not supported by DSAttention!" 
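Semantically, the TileLang kernels introduced above compute, for every query token, softmax attention restricted to its top-k selected KV slots, with `masks == True` marking disallowed positions and the log-sum-exp returned in base 2. A dense PyTorch reference of that contract is sketched below for `kv_group == 1`; it is a slow, illustrative restatement of what `sparse_mla_fwd_interface` is expected to return, not a drop-in replacement, and the shape and scale conventions are taken from the kernel's asserts and comments.

```python
# Dense (slow) reference for the sparse-MLA forward contract, kv_group == 1.
import math
import torch

def sparse_mla_reference(q, kv, indices, masks, d_v, sm_scale=None):
    """q: [B, S, H, D+Dt], kv: [B, S_kv, 1, D+Dt], indices: [B, S, 1, K],
    masks: [B, S, 1, S_kv] with True meaning the KV position is masked out.
    Returns out [B, S, H, d_v] and the log-sum-exp [B, S, H] in base 2."""
    B, S, H, d_full = q.shape
    if sm_scale is None:
        sm_scale = d_full ** -0.5
    out = torch.zeros(B, S, H, d_v)
    lse = torch.zeros(B, S, H)
    for b in range(B):
        for s in range(S):
            sel = indices[b, s, 0].long()                       # [K] positions into S_kv
            kv_sel = kv[b, sel, 0].float()                      # [K, D + Dt]
            scores = q[b, s].float() @ kv_sel.T * sm_scale      # [H, K]
            scores = scores.masked_fill(masks[b, s, 0][sel], float('-inf'))
            lse[b, s] = torch.logsumexp(scores, dim=-1) / math.log(2.0)
            out[b, s] = torch.softmax(scores, dim=-1) @ kv_sel[:, :d_v]
    return out.to(q.dtype), lse
```

A hypothetical small-shape comparison against the kernel would additionally have to respect the kernel's own requirements: contiguous bfloat16 inputs, int32 indices, and topk being a multiple of the index block size.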
+ assert not self.apply_rope_fusion, "RoPE fusion is not supported for DSAttention" +- + if self.fp8: + # cannot support first last layer bf16 with delayed scaling + if self.first_last_layers_bf16 and self.fp8_recipe == Fp8Recipe.delayed: diff --git a/docker/deepseekv32/transformers.patch b/docker/deepseekv32/transformers.patch new file mode 100644 index 000000000..a7631aa00 --- /dev/null +++ b/docker/deepseekv32/transformers.patch @@ -0,0 +1,20 @@ +diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py +index f6a12e7cef..22129a86ee 100644 +--- a/src/transformers/models/auto/configuration_auto.py ++++ b/src/transformers/models/auto/configuration_auto.py +@@ -1355,6 +1355,15 @@ class AutoConfig: + "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. " + ) + config_dict["model_type"] = "ministral" ++ if config_dict["model_type"] == "deepseek_v32": ++ logger.info( ++ "Detected deepseek_v32 model, treating as deepseek_v3 for compatibility." ++ ) ++ config_dict["model_type"] = "deepseek_v3" ++ if "architectures" in config_dict: ++ config_dict["architectures"] = [ ++ arch.replace("DeepseekV32", "DeepseekV3") for arch in config_dict["architectures"] ++ ] + + try: + config_class = CONFIG_MAPPING[config_dict["model_type"]] diff --git a/docker/justfile b/docker/justfile index a064fb426..d8a512d3e 100644 --- a/docker/justfile +++ b/docker/justfile @@ -24,3 +24,17 @@ _release-raw: docker tag radixark/miles:$IMAGE_TAG radixark/miles:latest docker push radixark/miles:latest fi + +debug: + #!/bin/bash + set -euxo pipefail + cd .. + + VERSION="$(cat docker/version.txt | tr -d '\n')" + IMAGE_TAG=${VERSION} + + docker build -f docker/Dockerfile . --build-arg HTTP_PROXY="$http_proxy" --build-arg HTTPS_PROXY="$https_proxy" --build-arg NO_PROXY="localhost,127.0.0.1" -t radixark/miles-test:$IMAGE_TAG + docker push radixark/miles-test:$IMAGE_TAG + + docker tag radixark/miles-test:$IMAGE_TAG radixark/miles-test:latest + docker push radixark/miles-test:latest diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch index 3a56ff4c2..8504c1885 100644 --- a/docker/patch/latest/megatron.patch +++ b/docker/patch/latest/megatron.patch @@ -12,38 +12,10 @@ index 41c21d93d..ef80f72d6 100644 err_msg = f'Common file {load_path} does not exist' if MultiStorageClientFeature.is_enabled(): diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py -index ccf5242a2..9b6d3e31f 100644 +index 5a1ea308d..aa701237f 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py -@@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li - _restore_dict_types(x_val, templ_val) - - -+@dataclass -+class MCoreMetadata(Metadata): -+ """Metadata with mcore specific data.""" -+ -+ # holds data related to flattened_range -+ # TODO: remove when flattened_range is properly removed -+ mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor -+ -+ - @dataclass(frozen=True) - class MCoreSavePlan(SavePlan): - """SavePlan with MCore specific data.""" -@@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): - def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: - """Merges MCore data for all plans.""" - global_plan, metadata = 
super().create_global_plan(all_plans) -- metadata.mcore_data = dict( -+ mcore_data = dict( - ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] - ) -+ metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) - return global_plan, metadata - - def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: -@@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): def _validate_global_shapes(self, metadata, sharded_tensors): for sh_ten in sharded_tensors: if sh_ten.key not in metadata.state_dict_metadata: @@ -60,7 +32,7 @@ index ccf5242a2..9b6d3e31f 100644 loaded_shape = metadata.state_dict_metadata[sh_ten.key].size expected_shape = self._expected_shape(sh_ten) if loaded_shape != expected_shape: -@@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): tensor_metadata = self.metadata.state_dict_metadata metadata_with_sizes = [ (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) @@ -69,7 +41,7 @@ index ccf5242a2..9b6d3e31f 100644 ] try: # Temporarily set sizes to expected shapes -@@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): planner=MCoreLoadPlanner( shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, @@ -77,31 +49,11 @@ index ccf5242a2..9b6d3e31f 100644 ), ) -diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b43..4451f2776 100644 ---- a/megatron/core/distributed/__init__.py -+++ b/megatron/core/distributed/__init__.py -@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads - from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel - from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel - from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig -+ -+# Backward compatibility patch for FSDP module reorganization -+import sys -+import importlib.util -+ -+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') -+if spec: -+ custom_fsdp = importlib.util.module_from_spec(spec) -+ spec.loader.exec_module(custom_fsdp) -+ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp -+ if hasattr(custom_fsdp, 'MegatronFSDP'): -+ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index 7727efe1e..966fe652a 100644 +index acb93ef78..d239db4ab 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py -@@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): ) for param in self.parameters(): @@ -109,44 +61,336 @@ index 7727efe1e..966fe652a 100644 if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, 
n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + 
kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * 
stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." 
+- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index 860ee64a9..80944b702 100755 +index e21127b87..712793853 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( - qk_l2_norm: Optional[bool] = False, - use_te_op_fuser: Optional[bool] = False, +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). -@@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( - ), - ), - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map={ - "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", - "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index 6aec66e6d..6ca48b55f 100644 +index a1230568c..1fd52f65a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py -@@ -355,6 +355,7 @@ class GPTModel(LanguageModule): +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[Tensor] = None, + mtp_kwargs: Optional[dict] = {}, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post -@@ -410,6 +411,7 @@ class 
GPTModel(LanguageModule): + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, @@ -154,7 +398,7 @@ index 6aec66e6d..6ca48b55f 100644 ) def _postprocess( -@@ -431,6 +433,7 @@ class GPTModel(LanguageModule): +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, @@ -162,22 +406,23 @@ index 6aec66e6d..6ca48b55f 100644 ): """Postprocesses decoder hidden states to generate logits or compute loss. -@@ -446,7 +449,7 @@ class GPTModel(LanguageModule): +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - - if mtp_in_postprocess: ++ + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, -@@ -465,25 +468,37 @@ class GPTModel(LanguageModule): - if not self.post_process: +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): return hidden_states -- if self.mtp_process: + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: - mtp_labels = labels.clone() -+ if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: + mtp_labels = mtp_kwargs['mtp_labels'].clone() + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + @@ -190,39 +435,22 @@ index 6aec66e6d..6ca48b55f 100644 + # Otherwise, roll the loss_mask to keep up with the mtp_labels + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) for mtp_layer_number in range(self.config.mtp_num_layers): - # output -- mtp_logits, _ = self.output_layer( -- hidden_states_list[mtp_layer_number + 1], -- weight=output_weight, -- runtime_gather_output=runtime_gather_output, -+ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} -+ output_layer_buffers = dict(self.output_layer.named_buffers()) -+ mtp_logits, _ = torch.func.functional_call( -+ self.output_layer, -+ {**output_layer_params, **output_layer_buffers}, -+ (hidden_states_list[mtp_layer_number + 1],), -+ { -+ "weight": output_weight.detach() if output_weight else None, -+ "runtime_gather_output": runtime_gather_output, -+ }, - ) # Calc loss for the current Multi-Token Prediction (MTP) layers. 
-- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) -- loss_mask, num_tokens = roll_tensor( -- loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group -+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ new_loss_mask, num_tokens = roll_tensor( -+ loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, ) -+ loss_mask = new_loss_mask * loss_mask - mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) - mtp_loss = loss_mask * mtp_loss - if self.training: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index a36b67364..ed8883e32 100644 +index 6e093f96f..eac21a3ea 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py -@@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) @@ -231,11 +459,20 @@ index a36b67364..ed8883e32 100644 # Grad scaler state. if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index a40c85a88..86688c331 100644 +index a273002b9..4f821cfd5 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py -@@ -9,6 +9,7 @@ from typing import Callable, List, Optional +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional import numpy as np import torch @@ -244,7 +481,7 @@ index a40c85a88..86688c331 100644 from .utils import GlobalMemoryBuffer, is_torch_min_version diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index 63ee9d1f5..b90b744c1 100644 +index ac839c21f..f18309217 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( @@ -274,153 +511,11 @@ index 63ee9d1f5..b90b744c1 100644 ) ops.append(recv_next_op) if len(ops) > 0: -diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py -index c749bac43..dde8d50e7 100644 ---- a/megatron/core/transformer/attention.py -+++ b/megatron/core/transformer/attention.py -@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC): - # Get the query, key and value tensors based on the type of attention - - # self or cross attn. 
- nvtx_range_push(suffix="qkv") -- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) -+ if self.config.use_gated_attention: -+ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) -+ else: -+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - nvtx_range_pop(suffix="qkv") - - # =================================================== -@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): - # Output. [sq, b, h] - # ================= - -+ if self.config.use_gated_attention: -+ nvtx_range_push(suffix="sigmoid_gate") -+ core_attn_out = core_attn_out * torch.sigmoid(gate) -+ nvtx_range_pop(suffix="sigmoid_gate") -+ - nvtx_range_push(suffix="linear_proj") - output, bias = self.linear_proj(core_attn_out) - nvtx_range_pop(suffix="linear_proj") -@@ -879,19 +887,34 @@ class SelfAttention(Attention): - model_comm_pgs=model_comm_pgs, - ) - -- self.linear_qkv = build_module( -- submodules.linear_qkv, -- self.config.hidden_size, -- self.query_projection_size + 2 * self.kv_projection_size, -- config=self.config, -- init_method=self.config.init_method, -- gather_output=False, -- bias=self.config.add_bias_linear or self.config.add_qkv_bias, -- skip_bias_add=False, -- is_expert=False, -- tp_comm_buffer_name='qkv', -- tp_group=self.model_comm_pgs.tp, -- ) -+ if self.config.use_gated_attention: -+ self.linear_qgkv = build_module( -+ submodules.linear_qkv, -+ self.config.hidden_size, -+ 2 * (self.query_projection_size + self.kv_projection_size), -+ config=self.config, -+ init_method=self.config.init_method, -+ gather_output=False, -+ bias=self.config.add_bias_linear or self.config.add_qkv_bias, -+ skip_bias_add=False, -+ is_expert=False, -+ tp_comm_buffer_name='qkv', -+ tp_group=self.model_comm_pgs.tp, -+ ) -+ else: -+ self.linear_qkv = build_module( -+ submodules.linear_qkv, -+ self.config.hidden_size, -+ self.query_projection_size + 2 * self.kv_projection_size, -+ config=self.config, -+ init_method=self.config.init_method, -+ gather_output=False, -+ bias=self.config.add_bias_linear or self.config.add_qkv_bias, -+ skip_bias_add=False, -+ is_expert=False, -+ tp_comm_buffer_name='qkv', -+ tp_group=self.model_comm_pgs.tp, -+ ) - - if submodules.q_layernorm is not None: - self.q_layernorm = build_module( -@@ -1036,6 +1059,65 @@ class SelfAttention(Attention): - - return query, key, value - -+ # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 -+ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): -+ """ -+ Derives `query`, `key` and `value` tensors from `hidden_states`. 
-+ """ -+ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] -+ mixed_qgkv, _ = self.linear_qgkv(hidden_states) -+ -+ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] -+ new_tensor_shape = mixed_qgkv.size()[:-1] + ( -+ self.num_query_groups_per_partition, -+ ( -+ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) -+ * self.hidden_size_per_attention_head -+ ), -+ ) -+ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) -+ -+ split_arg_list = [ -+ ( -+ self.num_attention_heads_per_partition -+ // self.num_query_groups_per_partition -+ * self.hidden_size_per_attention_head -+ ), -+ ( -+ self.num_attention_heads_per_partition -+ // self.num_query_groups_per_partition -+ * self.hidden_size_per_attention_head -+ ), -+ self.hidden_size_per_attention_head, -+ self.hidden_size_per_attention_head, -+ ] -+ -+ if SplitAlongDim is not None: -+ -+ # [sq, b, ng, (np/ng + 2) * hn] -+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] -+ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) -+ else: -+ -+ # [sq, b, ng, (np/ng + 2) * hn] -+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] -+ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) -+ -+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -+ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) -+ gate = gate.reshape(query.size(0), query.size(1), -1) -+ -+ if self.q_layernorm is not None: -+ query = self.q_layernorm(query) -+ -+ if self.k_layernorm is not None: -+ key = self.k_layernorm(key) -+ -+ if self.config.test_mode: -+ self.run_realtime_tests() -+ -+ return query, gate, key, value -+ - def backward_dw(self) -> NoReturn: - """Execute weight update operations""" - self._backward_qkv_proj() diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py -index 235b6f6af..fbcffe278 100644 +index 28cff06f5..58dc4bb70 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py -@@ -566,6 +566,9 @@ def topk_routing_with_score_function( +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( else: return torch.topk(scores, k=topk, dim=1) @@ -431,12 +526,12 @@ index 235b6f6af..fbcffe278 100644 if use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py -index 6b20b8622..459e65921 100644 +index 16fc9d9af..517944f25 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py -@@ -156,6 +156,9 @@ class TopKRouter(Router): - self.local_tokens_per_expert = None - self.expert_bias = None +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + from miles.utils.routing_replay import register_routing_replay + register_routing_replay(self) @@ -445,7 +540,7 @@ index 6b20b8622..459e65921 100644 """ Maintain the expert bias in float32. 
diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py -index b7884e18e..f0104f861 100755 +index a8f4abfcd..f33f6f05e 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union @@ -454,186 +549,27 @@ index b7884e18e..f0104f861 100755 from torch import Tensor +import warnings - from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel + from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict -@@ -105,17 +106,21 @@ def tie_output_layer_state_dict( - ) - - --def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): -- """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. - -- This function extends the original roll_tensor to support Context Parallelism, which allows -- MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, -- and tensor rolling requires communication between adjacent CP ranks to properly handle the -- boundary conditions. -+def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): -+ """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. -+ -+ This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. -+ When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires -+ communication between adjacent CP ranks to properly handle the boundary conditions. -+ When packed sequences are used, rolling is performed within each individual sequence boundary -+ to prevent mixing tokens between different packed sequences. - - For CP=1 (default behavior): Uses standard torch.roll with zero padding - For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges - boundary elements between adjacent CP ranks to maintain sequence continuity. -+ For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. -+ - - Args: - tensor (Tensor): The input tensor to roll. -@@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): - dims (int): The dimension to roll (typically -1 for sequence dimension). - cp_group (ProcessGroup): The context parallelism process group. If None or size=1, - falls back to standard rolling behavior. -+ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. -+ If provided, rolling respects sequence boundaries. - Returns: - tuple: (rolled_tensor, sum_of_rolled_tensor) - """ -+ -+ if packed_seq_params is not None: -+ return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) -+ - # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) - if cp_group is None or cp_group.size() == 1: - rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) -@@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): - - return rolled_tensor, rolled_tensor.sum() - -+def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): -+ """Roll tensor with packed sequence support. -+ -+ This function handles rolling for packed sequences by respecting sequence boundaries -+ defined in packed_seq_params.cu_seqlens. 
Rolling is performed within each individual -+ sequence to prevent mixing tokens between different packed sequences. When Context -+ Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata -+ so we slice out the portion of every packed sequence that lives on the current rank and -+ reuse the standard CP boundary exchange to populate the rolling window. -+ -+ Args: -+ tensor (Tensor): The input tensor to roll. -+ shifts (int): The shift of the tensor (typically -1 for MTP). -+ dims (int): The dimension to roll (typically -1 for sequence dimension). -+ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. -+ cp_group (ProcessGroup): The context parallelism process group. -+ -+ Returns: -+ tuple: (rolled_tensor, sum_of_rolled_tensor) -+ """ -+ -+ # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. -+ assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." -+ assert shifts == -1, "Packed sequence roll only supports a single-token left shift." -+ cu_seqlens = packed_seq_params.cu_seqlens_q -+ assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." -+ -+ rolled_tensor = tensor.clone() -+ -+ cp_size = cp_group.size() if cp_group is not None else 1 -+ if cp_size == 1: -+ # CP disabled: simply roll inside each packed sequence boundary. -+ for i in range(len(cu_seqlens) - 1): -+ start_idx = cu_seqlens[i] -+ end_idx = cu_seqlens[i + 1] -+ seq_slice = tensor[..., start_idx:end_idx] -+ rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) -+ rolled_seq[..., shifts:] = 0 -+ rolled_tensor[..., start_idx:end_idx] = rolled_seq -+ return rolled_tensor, rolled_tensor.sum() -+ -+ # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). 
-+ local_rank = torch.distributed.get_rank(group=cp_group) -+ global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) -+ next_rank = global_ranks[(local_rank + 1) % cp_size] -+ prev_rank = global_ranks[(local_rank - 1) % cp_size] -+ -+ # iterate over each sequence individually -+ for i in range(len(cu_seqlens) - 1): -+ start_idx = cu_seqlens[i] -+ end_idx = cu_seqlens[i + 1] -+ -+ # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx -+ local_start_idx = start_idx // cp_size -+ local_end_idx = end_idx // cp_size -+ tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() -+ -+ # The following code is very similar as the code in roll_tensor function -+ local_chunks = tensor_slice.chunk(2, dim=dims) -+ rolled_chunks = [ -+ torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks -+ ] -+ -+ tensor_send_list = [] -+ tensor_recv_list = [] -+ for chunk in rolled_chunks: -+ boundary = chunk.select(dims, shifts).contiguous().clone() -+ tensor_send_list.append(boundary) -+ tensor_recv_list.append(torch.empty_like(boundary)) -+ -+ ops = [] -+ if local_rank != 0: -+ ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) -+ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) -+ else: -+ tensor_recv_list[1].zero_() -+ -+ if local_rank != cp_size - 1: -+ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) -+ ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) -+ else: -+ tensor_recv_list[0].copy_(tensor_send_list[1]) -+ -+ for op in ops: -+ op.wait() -+ -+ index = [slice(None)] * rolled_chunks[0].dim() -+ index[dims] = shifts -+ for chunk, recv in zip(rolled_chunks, tensor_recv_list): -+ chunk[tuple(index)] = recv -+ -+ seq_result = torch.cat(rolled_chunks, dim=dims) -+ -+ # update the rolled tensor -+ rolled_tensor[..., local_start_idx:local_end_idx] = seq_result -+ -+ return rolled_tensor, rolled_tensor.sum() - - class MTPLossLoggingHelper: - """Helper class for logging MTP losses.""" -@@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): - def _get_embeddings( - self, - input_ids: torch.Tensor, -- position_ids: torch.Tensor, - embedding: Callable, - hidden_states: torch.Tensor, -+ position_ids: Optional[torch.Tensor] = None, -+ packed_seq_params: Optional[PackedSeqParams] = None, - ): - """ - Preprocesses input data for the Multi-Token Prediction (MTP) layers. -@@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): - sequence length, b is the batch size, and h is the hidden size. - """ - # Calc logits for the current Multi-Token Prediction (MTP) layers. -- input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) -- position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) -+ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ -+ # Prepare/roll position ids only when applicable. -+ if position_ids is None: -+ # Fallback position ids for learned absolute embedding. 
-+ seq_len = input_ids.size(-1) -+ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) -+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids) -+ -+ position_ids, _ = roll_tensor( -+ position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params -+ ) +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) # embedding decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + decoder_input = decoder_input.detach() @@ -643,7 +579,7 @@ index b7884e18e..f0104f861 100755 return input_ids, position_ids, decoder_input, hidden_states -@@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): return hidden_states def _checkpointed_forward(self, forward_func, *args, **kwargs): @@ -693,14 +629,9 @@ index b7884e18e..f0104f861 100755 + tensor_args_tuple = tuple(tensor_args) + def checkpoint_handler(): -- """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" -+ """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" if self.config.fp8: - from megatron.core.extensions.transformer_engine import te_checkpoint - - return te_checkpoint( -- forward_func, -+ run, +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, parallel_state.get_tensor_model_parallel_group(), @@ -715,43 +646,25 @@ index b7884e18e..f0104f861 100755 ) if self.config.recompute_method == 'uniform': -@@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ - assert context is None, f"multi token prediction + cross attention is not yet supported." -- assert ( -- packed_seq_params is None -- ), f"multi token prediction + sequence packing is not yet supported." 
- - input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( - input_ids=input_ids, - position_ids=position_ids, - embedding=embedding, - hidden_states=hidden_states, -+ packed_seq_params=packed_seq_params, - ) - - if self.config.recompute_granularity == 'full' and self.training: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index d55bebe7e..1eecbbd38 100644 +index e2705bd9f..a0aa109b5 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py -@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): - qk_layernorm: bool = False - """Whether to apply `normalization` type of normalization to the query and key embeddings.""" +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False -+ use_gated_attention: bool = False + test_mode: bool = False """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 84f22bdea..f0f3f8e86 100644 +index 3ea405770..5a42001b9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py -@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -759,7 +672,7 @@ index 84f22bdea..f0f3f8e86 100644 pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp @@ -767,7 +680,7 @@ index 84f22bdea..f0f3f8e86 100644 # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) @@ -781,9 +694,9 @@ index 84f22bdea..f0f3f8e86 100644 # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, -@@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): - # [Module 9: BiasDropoutFusion] - self.mlp_bda = build_module(submodules.mlp_bda) +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, @@ -795,7 +708,7 @@ index 84f22bdea..f0f3f8e86 100644 self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False -@@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): attention_output_with_bias[0] 
) @@ -806,7 +719,7 @@ index 84f22bdea..f0f3f8e86 100644 # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") -@@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) @@ -818,30 +731,20 @@ index 84f22bdea..f0f3f8e86 100644 # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index e3459c5ee..7346bf35b 100644 +index b267c8a81..83736acdc 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py -@@ -937,8 +937,6 @@ def validate_args(args, defaults={}): - # MoE Spec check - if args.num_experts == 0: - args.num_experts = None -- if args.num_experts is not None: -- assert args.spec is None, "Model Spec must be None when using MoEs" - if args.num_experts is not None and args.moe_ffn_hidden_size is None: - args.moe_ffn_hidden_size = args.ffn_hidden_size - print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") -@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): - if args.is_hybrid_model: - kw_args['is_hybrid_model'] = args.is_hybrid_model +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm -+ kw_args['use_gated_attention'] = args.use_gated_attention + # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. -@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): action='store_true', help='If set, use original BERT residula connection ' 'ordering.') @@ -855,15 +758,15 @@ index e3459c5ee..7346bf35b 100644 help='Use OpenAIs GeLU implementation. 
This option' 'should not be used unless for backward compatibility' diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py -index 5cf222ccc..d1554ca4c 100644 +index 13b7526ca..6c590f653 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py -@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): - f"The transformers library must be installed to use huggingface_tokenizer_provider" - ) - -+ if "trust_remote_code" not in kwargs: -+ kwargs["trust_remote_code"] = True +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there self._tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index de12cdd43..b801e1162 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,8 +1,25 @@ +diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py +index f807deedb..4c0407dec 100644 +--- a/python/sglang/srt/configs/model_config.py ++++ b/python/sglang/srt/configs/model_config.py +@@ -269,6 +269,12 @@ class ModelConfig: + ): + self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + ++ if ( ++ is_draft_model ++ and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM" ++ ): ++ self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" ++ + if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": + self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" + diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index ef52bda7f..537d892dc 100644 +index 199885244..742ad0639 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py -@@ -296,6 +296,13 @@ class DecodePreallocQueue: +@@ -314,6 +314,13 @@ class DecodePreallocQueue: ) return kv_manager @@ -17,10 +34,10 @@ index ef52bda7f..537d892dc 100644 """Add a request to the pending queue.""" if self._check_if_req_exceed_kv_capacity(req): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index d4414d084..c5fb10155 100644 +index 32e8c0b69..df913da7b 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py -@@ -1074,6 +1074,19 @@ class MooncakeKVManager(CommonKVManager): +@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager): f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" ) @@ -41,10 +58,10 @@ index d4414d084..c5fb10155 100644 class MooncakeKVSender(CommonKVSender): diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index 952374ed5..239ac2571 100644 +index ac11013f8..478e469f6 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py -@@ -305,6 +305,13 @@ class PrefillBootstrapQueue: +@@ -309,6 +309,13 @@ class PrefillBootstrapQueue: else: return bootstrapped_reqs, failed_reqs @@ 
-58,60 +75,11 @@ index 952374ed5..239ac2571 100644 class SchedulerDisaggregationPrefillMixin: """ -diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py -index 86c53f26b..52acf95b9 100644 ---- a/python/sglang/srt/distributed/device_communicators/pynccl.py -+++ b/python/sglang/srt/distributed/device_communicators/pynccl.py -@@ -380,3 +380,9 @@ class PyNcclCommunicator: - - self.disabled = old_disable - self.stream = old_stream -+ -+ def nccl_pause(self): -+ self.nccl.ncclPause(self.comm) -+ -+ def nccl_resume(self): -+ self.nccl.ncclResume(self.comm) -diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -index 6b12f2922..7028a4e46 100644 ---- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -+++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py -@@ -304,6 +304,17 @@ class NCCLLibrary: - Function("ncclGroupEnd", ncclResult_t, []), - ] - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ exported_functions.extend( -+ [ -+ # ncclResult_t ncclPause(ncclComm_t comm); -+ Function("ncclPause", ncclResult_t, [ncclComm_t]), -+ # ncclResult_t ncclResume(ncclComm_t comm); -+ Function("ncclResume", ncclResult_t, [ncclComm_t]), -+ Function("ncclSetGroupID", ncclResult_t, [ctypes.c_int]), -+ ] -+ ) -+ - exported_functions_symm_mem = [ - # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); - Function( -@@ -551,6 +562,12 @@ class NCCLLibrary: - def ncclGroupEnd(self) -> None: - self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) - -+ def ncclPause(self, comm: ncclComm_t) -> None: -+ self.NCCL_CHECK(self._funcs["ncclPause"](comm)) -+ -+ def ncclResume(self, comm: ncclComm_t) -> None: -+ self.NCCL_CHECK(self._funcs["ncclResume"](comm)) -+ - - __all__ = [ - "NCCLLibrary", diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index cf90f6fe0..11d26df81 100644 +index 0478526ef..cfb1aa669 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py -@@ -1780,7 +1780,10 @@ def get_tensor_model_parallel_world_size(): +@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" @@ -124,101 +92,125 @@ index cf90f6fe0..11d26df81 100644 def get_pipeline_model_parallel_world_size(): diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 67a082ea6..390365864 100644 +index 21909706b..8fac5f162 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py -@@ -183,6 +183,7 @@ class Engine(EngineBase): - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - return_hidden_states: bool = False, -+ return_routed_experts: bool = False, - stream: bool = False, - bootstrap_host: Optional[Union[List[str], str]] = None, - bootstrap_port: Optional[Union[List[int], int]] = None, -@@ -218,6 +219,7 @@ class Engine(EngineBase): - lora_path=lora_path, - custom_logit_processor=custom_logit_processor, - return_hidden_states=return_hidden_states, -+ return_routed_experts=return_routed_experts, - stream=stream, - bootstrap_host=bootstrap_host, - bootstrap_port=bootstrap_port, -diff --git 
a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py -index 9f556a885..992843285 100644 ---- a/python/sglang/srt/layers/attention/vision.py -+++ b/python/sglang/srt/layers/attention/vision.py -@@ -518,11 +518,25 @@ class VisionAttention(nn.Module): - self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size - - if self.qk_normalization: -+ norm_kwargs = ( -+ dict( -+ weight_dtype=torch.float32, -+ cast_x_before_out_mul=True, -+ ) -+ if get_global_server_args().rl_on_policy_target is not None -+ else {} -+ ) - self.q_norm = RMSNorm( -- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim -+ self.dummy_dim, -+ eps=layer_norm_eps, -+ var_hidden_size=embed_dim, -+ **norm_kwargs, - ) - self.k_norm = RMSNorm( -- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim -+ self.dummy_dim, -+ eps=layer_norm_eps, -+ var_hidden_size=embed_dim, -+ **norm_kwargs, - ) +@@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import ( + InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, + MultimodalDataInputFormat, ++ PostProcessWeightsReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + RpcReqInput, +@@ -593,6 +594,24 @@ class Engine(EngineBase): + self.tokenizer_manager.update_weights_from_ipc(obj, None) + ) - # Select attention backend via a unified method -@@ -648,6 +662,15 @@ class VisionAttention(nn.Module): - if x.dim() == 2: - x = x.unsqueeze(0) - assert x.dim() == 3, x.shape -+ if ( -+ get_global_server_args().rl_on_policy_target is not None -+ and position_embeddings is not None -+ ): -+ assert isinstance(position_embeddings, tuple), ( -+ "expected position_embeddings to be a tuple of two tensors,\n" -+ f"but got {type(position_embeddings)}, change if needed" -+ ) -+ position_embeddings = tuple(p.to(x.dtype) for p in position_embeddings) - x_shape = x.shape - bsz, s, _ = x_shape - head = self.num_attention_heads_per_partition -diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py -index 932f52aeb..79c6b664f 100644 ---- a/python/sglang/srt/layers/communicator.py -+++ b/python/sglang/srt/layers/communicator.py -@@ -372,6 +372,7 @@ class LayerCommunicator: - residual: torch.Tensor, - forward_batch: ForwardBatch, - quant_format: str = "", -+ post_residual_addition: Optional[torch.Tensor] = None, - ): - if get_attn_tp_context().input_scattered: - hidden_states, residual = self._tp_reduce_scatter( -@@ -453,7 +454,9 @@ class LayerCommunicator: - ) - else: - hidden_states, residual = self.input_layernorm( -- hidden_states, residual -+ hidden_states, -+ residual, -+ post_residual_addition, - ) - - hidden_states = self._communicate_simple_fn( ++ def post_process_weights( ++ self, ++ restore_weights_before_load: bool = False, ++ post_process_quantization: bool = False, ++ ): ++ """ ++ Optional post-processing for updated weights (e.g., Marlin conversion). ++ Should be called after weight update is finished. 
++ """ ++ obj = PostProcessWeightsReqInput( ++ restore_weights_before_load=restore_weights_before_load, ++ post_process_quantization=post_process_quantization, ++ ) ++ ++ return self.loop.run_until_complete( ++ self.tokenizer_manager.post_process_weights(obj, None) ++ ) ++ + def get_weights_by_name(self, name: str, truncate_size: int = 100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) +diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py +index 88705cc35..c8dc052f1 100644 +--- a/python/sglang/srt/entrypoints/http_server.py ++++ b/python/sglang/srt/entrypoints/http_server.py +@@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import ( + OpenSessionReqInput, + ParseFunctionCallReq, + PauseGenerationReqInput, ++ PostProcessWeightsReqInput, + ProfileReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, +@@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + ++@app.post("/post_process_weights") ++async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): ++ """ ++ Optional post-processing for updated weights (e.g., Marlin conversion). ++ This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. ++ """ ++ success, message = await _global_state.tokenizer_manager.post_process_weights( ++ req, request ++ ) ++ ++ content = {"success": success, "message": message} ++ return ORJSONResponse( ++ content, status_code=200 if success else HTTPStatus.BAD_REQUEST ++ ) ++ + + @app.post("/update_weight_version") + async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): +diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +index c9e82e4b1..58270e34a 100644 +--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py ++++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +@@ -3,6 +3,7 @@ from __future__ import annotations + from abc import ABC, abstractmethod + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + ++import os + import torch + from einops import rearrange + +@@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp): + @torch.compile(dynamic=True) + def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): + weights, _ = self.weights_proj(x.float()) ++ if weights.shape[1] < 32: ++ assert 32 % weights.shape[1] == 0 ++ weights = weights.repeat_interleave(32 // weights.shape[1], dim=1) + weights = weights * self.n_heads**-0.5 + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + return weights +@@ -278,7 +282,10 @@ class Indexer(MultiPlatformOp): + key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 + ) + +- _, k_rope = self.rotary_emb(positions, k_rope, k_rope) ++ if os.environ.get("USE_FIRST_HALF_ROPE", "0") == "1": ++ k_rope, _ = self.rotary_emb(positions, k_rope, k_rope) ++ else: ++ _, k_rope = self.rotary_emb(positions, k_rope, k_rope) + key[..., : self.rope_head_dim] = k_rope + key = rotate_activation(key) + +@@ -837,6 +844,9 @@ class Indexer(MultiPlatformOp): + query, key = self._get_q_k_bf16( + q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch + ) ++ if query.shape[1] < 32: ++ assert 32 % query.shape[1] == 0 ++ query = query.repeat_interleave(32//query.shape[1], dim=1) + + if 
enable_dual_stream: + current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py -index 3293a8a59..a075b71ce 100644 +index b07164c53..8e6722ce0 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py -@@ -84,15 +84,12 @@ class RMSNorm(CustomOp): +@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): eps: float = 1e-6, var_hidden_size: Optional[int] = None, cast_x_before_out_mul: bool = False, @@ -236,45 +228,13 @@ index 3293a8a59..a075b71ce 100644 self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( -@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, -+ post_residual_addition: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if self.variance_size_override is not None: -- return self.forward_native(x, residual) -+ return self.forward_native(x, residual, post_residual_addition) - if is_batch_invariant_mode_enabled(): - if ( - residual is not None - or get_global_server_args().rl_on_policy_target == "fsdp" - ): -- return self.forward_native(x, residual) -+ return self.forward_native(x, residual, post_residual_addition) - return rms_norm_batch_invariant( - x, - self.weight.data, - self.variance_epsilon, - ) - if residual is not None: -+ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). -+ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm -+ if post_residual_addition is not None: -+ residual = residual + post_residual_addition - fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) - return x, residual - out = rmsnorm(x, self.weight.data, self.variance_epsilon) -@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, -+ post_residual_addition: Optional[torch.Tensor] = None, +@@ -194,10 +191,22 @@ class RMSNorm(MultiPlatformOp): ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not x.is_contiguous(): x = x.contiguous() - orig_dtype = self.override_orig_dtype or x.dtype + orig_dtype = x.dtype + post_residual_addition = kwargs.get("post_residual_addition") + + if residual is not None and not self.fp32_residual: + x = ( @@ -289,30 +249,27 @@ index 3293a8a59..a075b71ce 100644 + residual = x.clone() x = x.to(torch.float32) - if residual is not None: -- x = x + residual.to(torch.float32) ++ if residual is not None and self.fp32_residual: + x = ( + x + + residual.to(torch.float32) +@@ -207,10 +216,7 @@ class RMSNorm(MultiPlatformOp): + else 0.0 + ) + ) - if self.fp32_residual: - residual = x.clone() - else: - residual = x.to(orig_dtype) -+ if residual is not None and self.fp32_residual: -+ x = ( -+ x -+ + residual.to(torch.float32) -+ + ( -+ post_residual_addition.to(torch.float32) -+ if post_residual_addition is not None -+ else 0.0 -+ ) -+ ) + residual = x.to(orig_dtype) hidden_size = x.shape[-1] if hidden_size != self.hidden_size: diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py -index 522865765..733bad5f2 100644 +index fa7431048..cd33ea735 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py -@@ -841,11 +841,6 @@ class LogitsProcessor(nn.Module): +@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): None, # bias True, # is_vnni ) @@ -325,7 +282,7 @@ index 
522865765..733bad5f2 100644 logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py -index e7d5a67cc..639e47163 100644 +index a1885fade..14d692365 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -14,6 +14,7 @@ import torch.nn.functional as F @@ -335,8 +292,8 @@ index e7d5a67cc..639e47163 100644 +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( cpu_has_amx_support, - direct_register_custom_op, -@@ -626,7 +627,10 @@ def fused_experts_impl( + get_bool_env_var, +@@ -573,7 +574,10 @@ def fused_experts_impl( ).squeeze(dim=1) else: # According to micro benchmark results, torch.compile can get better performance for small token. @@ -349,389 +306,177 @@ index e7d5a67cc..639e47163 100644 intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py -new file mode 100644 -index 000000000..e16817f1f ---- /dev/null +index 00bd68755..5a3ca8a67 100644 +--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py -@@ -0,0 +1,279 @@ -+import logging -+from abc import ABC +@@ -1,5 +1,6 @@ + import logging + from abc import ABC +from contextlib import contextmanager -+from typing import Optional -+ -+import numpy as np -+import torch -+ -+from sglang.srt.configs.model_config import ModelConfig -+from sglang.srt.layers.dp_attention import ( -+ get_attention_dp_rank, -+ get_dp_local_info, -+ is_dp_attention_enabled, + from typing import Optional + + import numpy as np +@@ -8,13 +9,18 @@ import torch + + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.layers.dp_attention import ( ++ attn_tp_all_gather_into_tensor, + get_attention_dp_rank, ++ get_attention_tp_size, + get_dp_local_info, + is_dp_attention_enabled, + ) + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.server_args import get_global_server_args ++from sglang.srt.layers.moe import ( ++ get_moe_a2a_backend, +) -+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -+from sglang.srt.server_args import get_global_server_args -+ -+logger = logging.getLogger(__name__) -+ -+_GB = 1024 * 1024 * 1024 -+_MB = 1024 * 1024 -+ -+ -+def get_tensor_size_bytes(t: torch.Tensor): -+ return np.prod(t.shape) * t.dtype.itemsize -+ -+ -+class _RoutedExpertsDeviceCache: -+ def __init__( -+ self, -+ max_running_requests: int, -+ num_hidden_layers: int, -+ num_experts_per_tok: int, -+ num_fused_shared_experts: int, -+ device: str, -+ ) -> None: -+ self.buffer = torch.zeros( -+ ( -+ max( -+ get_global_server_args().chunked_prefill_size -+ * get_global_server_args().dp_size, -+ max_running_requests, + + logger = logging.getLogger(__name__) + +@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + device=device, + ) + ++ if get_moe_a2a_backend().is_deepep(): ++ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 ++ self.gather_buffer = torch.empty( ++ ( ++ self.device_cache.buffer.shape[0] * attn_tp_size, ++ self.device_cache.buffer.shape[2], + ), -+ num_hidden_layers, -+ 
num_experts_per_tok + num_fused_shared_experts, -+ ), -+ dtype=torch.int32, -+ device=device, -+ ) -+ self._finalize_allocation_log() -+ -+ def get_buffer_size_bytes(self): -+ assert hasattr(self, "buffer") -+ return get_tensor_size_bytes(self.buffer) -+ -+ def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): -+ assert layer_id is not None, "capturing routing experts but get layer_id None" -+ batch, _ = topk_ids.shape -+ self.buffer[:batch, layer_id, :] = topk_ids -+ -+ def _finalize_allocation_log(self): -+ """Common logging and memory usage computation for captured experts buffers.""" -+ buffer_size_MB = self.get_buffer_size_bytes() / _MB -+ logger.info( -+ f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB" -+ ) -+ -+ -+class _RoutedExpertsHostCache: -+ def __init__( -+ self, -+ num_tokens: int, -+ num_hidden_layers: int, -+ num_experts_per_tok: int, -+ ) -> None: -+ self.num_tokens = num_tokens -+ self.buffer = torch.zeros( -+ ( -+ num_tokens, -+ num_hidden_layers, -+ num_experts_per_tok, -+ ), -+ dtype=torch.int32, -+ device="cpu", -+ pin_memory=True, -+ ) -+ self._finalize_allocation_log() -+ -+ def get_buffer_size_bytes(self): -+ assert hasattr(self, "buffer") -+ return get_tensor_size_bytes(self.buffer) -+ -+ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): -+ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) -+ -+ def _finalize_allocation_log(self): -+ """Common logging and memory usage computation for captured experts buffers.""" -+ buffer_size_GB = self.get_buffer_size_bytes() / _GB -+ logger.info( -+ f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB" -+ ) -+ -+ -+class RoutedExpertsCapturer(ABC): -+ @staticmethod -+ def create( -+ enable: bool, -+ model_config: ModelConfig, -+ num_fused_shared_experts: int, -+ num_tokens: int, -+ max_running_requests: int, -+ device: str, -+ ): -+ if enable: -+ return _RoutedExpertsCapturerReal( -+ model_config, -+ num_tokens=num_tokens, -+ max_running_requests=max_running_requests, -+ num_fused_shared_experts=num_fused_shared_experts, ++ dtype=torch.int32, + device=device, + ) -+ else: -+ return _RoutedExpertsCapturerNoop() -+ -+ def capture(self, layer_id: int, topk_ids: torch.Tensor): -+ raise NotImplementedError -+ -+ def get_routed_experts( -+ self, -+ req_pool_idx: int, -+ seqlen: int, -+ req_to_token_pool: ReqToTokenPool, -+ ): -+ raise NotImplementedError -+ -+ def sync_fwd_experts_buffer_DtoH( -+ self, -+ device_loc: torch.Tensor, -+ cpu_loc: torch.Tensor, -+ can_run_graph: bool, -+ cuda_graph_batch: int, -+ ): -+ raise NotImplementedError -+ -+ @contextmanager -+ def with_forward(self, forward_batch): -+ yield -+ -+ def get_host_cache(self): -+ raise NotImplementedError -+ -+ def get_device_cache(self): -+ raise NotImplementedError -+ -+ -+class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): -+ """Capturer for routed experts with host buffer""" -+ -+ def __init__( -+ self, -+ model_config: ModelConfig, -+ num_tokens: int, -+ max_running_requests: int, -+ num_fused_shared_experts: int, -+ device: str, -+ ): -+ self.forward_batch = None -+ self.num_fused_shared_experts = num_fused_shared_experts -+ self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers -+ self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok -+ -+ self.host_cache = _RoutedExpertsHostCache( -+ num_tokens=num_tokens, -+ 
num_hidden_layers=self.num_hidden_layers, -+ num_experts_per_tok=self.num_experts_per_tok, -+ ) -+ -+ self.device_cache = _RoutedExpertsDeviceCache( -+ max_running_requests=max_running_requests, -+ num_hidden_layers=self.num_hidden_layers, -+ num_experts_per_tok=self.num_experts_per_tok, -+ num_fused_shared_experts=self.num_fused_shared_experts, -+ device=device, -+ ) -+ -+ def capture(self, layer_id: int, topk_ids: torch.Tensor): -+ self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) -+ -+ def sync_fwd_experts_buffer_DtoH( -+ self, -+ device_loc: torch.Tensor, -+ cpu_loc: torch.Tensor, -+ can_run_graph: bool, -+ cuda_graph_batch: int, -+ ): -+ if is_dp_attention_enabled(): -+ local_start_pos, local_num_tokens = get_dp_local_info(self.forward_batch) -+ # handle with cuda graph padding -+ if can_run_graph: -+ local_start_pos = get_attention_dp_rank() * cuda_graph_batch -+ local_end_pos = local_start_pos + local_num_tokens -+ else: -+ local_end_pos = local_start_pos + local_num_tokens -+ else: -+ local_start_pos = 0 -+ local_end_pos = device_loc.shape[0] + -+ self.host_cache.buffer[cpu_loc] = self.device_cache.buffer[ -+ local_start_pos:local_end_pos, :, : self.num_experts_per_tok -+ ].cpu() -+ -+ def get_routed_experts( -+ self, -+ req_pool_idx: int, -+ seqlen: int, -+ req_to_token_pool: ReqToTokenPool, -+ ): -+ cache_pool_idx = ( -+ req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() -+ ) -+ return self.get_host_cache().buffer[cache_pool_idx] -+ -+ @contextmanager -+ def with_forward(self, forward_batch): -+ self.forward_batch = forward_batch -+ yield -+ -+ def get_host_cache(self): -+ return self.host_cache -+ -+ def get_device_cache(self): -+ return self.device_cache -+ -+ -+class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): -+ def __init__(self): -+ pass -+ -+ def capture(self, layer_id: int, topk_ids: torch.Tensor): -+ pass -+ -+ def get_routed_experts( -+ self, -+ req_pool_idx: int, -+ seqlen: int, -+ req_to_token_pool: ReqToTokenPool, -+ ): -+ pass -+ -+ def sync_fwd_experts_buffer_DtoH( -+ self, -+ device_loc: torch.Tensor, -+ cpu_loc: torch.Tensor, -+ can_run_graph: bool, -+ cuda_graph_batch: int, -+ ): -+ pass -+ -+ @contextmanager -+ def with_forward(self, forward_batch): -+ yield -+ -+ def get_host_cache(self): -+ pass -+ -+ def get_device_cache(self): -+ pass -+ -+ -+_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() + def _sync_fwd_experts_buffer_DtoH( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: int, + ): +- if is_dp_attention_enabled(): ++ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer ++ # contains data from all DP ranks. We should not slice by DP rank in this case. 
++ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + # handle with cuda graph padding + if can_run_graph: +@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + ].cpu() + + def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ if get_moe_a2a_backend().is_deepep(): ++ local_topk_ids = topk_ids ++ topk_ids = self.gather_buffer[ ++ : local_topk_ids.size(0) * get_attention_tp_size() ++ ] ++ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + def get_routed_experts( +diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +index c5e5a11fc..6b788fb1d 100644 +--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py ++++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +@@ -1016,13 +1016,38 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + ++ if not hasattr(layer, "_original_shapes"): ++ layer._original_shapes = {} ++ ++ # Force record: these are the target GPTQ shapes for rollback. ++ layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) ++ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) ++ ++ # Also record the shapes of the scales. ++ layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) ++ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) ++ + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ++ ++ # Skip if the layer is already converted to Marlin format to prevent double-packing. ++ if getattr(layer, "is_marlin_converted", False): ++ return ++ ++ if not hasattr(layer, "_original_shapes"): ++ layer._original_shapes = {} + + def replace_tensor(name, new_t): ++ target_attr = getattr(layer, name) ++ ++ # Only save if the key doesn't exist to prevent overwriting with Marlin shapes. ++ if name not in layer._original_shapes: ++ # This is a safety check; `create_weights` usually handles this already. 
++ layer._original_shapes[name] = tuple(target_attr.shape) ++ + # It is important to use resize_() here since it ensures + # the same buffer is reused +- getattr(layer, name).resize_(new_t.shape) +- getattr(layer, name).copy_(new_t) ++ target_attr.resize_(new_t.shape) ++ target_attr.copy_(new_t) + del new_t + + num_experts = layer.w13_weight_g_idx.shape[0] +@@ -1078,7 +1103,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.w13_weight_packed.shape[2], + self.num_bits, + ) +- replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) ++ replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, +@@ -1086,7 +1111,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.w2_weight_packed.shape[2], + self.num_bits, + ) +- replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) ++ replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, +@@ -1094,7 +1119,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.w13_weight_scale.shape[2], + self.group_size, + ) +- replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) ++ replace_tensor("w13_weight_scale", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, +@@ -1103,7 +1128,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + layer.w2_weight_scale.shape[2], + self.group_size, + ) +- replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) ++ replace_tensor("w2_weight_scale", marlin_w2_scales) + ++ layer.is_marlin_converted = True ++ ++ def restore_weights_before_loading(self, layer: torch.nn.Module): ++ """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" ++ if not hasattr(layer, "_original_shapes"): ++ return + -+def get_global_experts_capturer(): -+ return _global_expert_capturer ++ for name, orig_shape in layer._original_shapes.items(): ++ param = getattr(layer, name, None) + ++ if param is not None and param.shape != orig_shape: ++ param.resize_(orig_shape) + -+def set_global_experts_capturer(capturer: RoutedExpertsCapturer): -+ global _global_expert_capturer -+ _global_expert_capturer = capturer -\ No newline at end of file -diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py -index a802647e8..0fd550c0c 100644 ---- a/python/sglang/srt/layers/moe/topk.py -+++ b/python/sglang/srt/layers/moe/topk.py -@@ -48,6 +48,7 @@ from sglang.srt.eplb.expert_location_dispatch import ( - ) - from sglang.srt.layers.dp_attention import is_allocation_symmetric - from sglang.srt.layers.moe import get_moe_runner_backend -+from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer - from sglang.srt.utils import ( - cpu_has_amx_support, - get_bool_env_var, -@@ -212,6 +213,7 @@ class TopK(CustomOp): - self, - top_k: int, - *, -+ layer_id: Optional[int] = None, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, -@@ -233,6 +235,7 @@ class TopK(CustomOp): - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - -+ self.layer_id = layer_id - self.topk_config = TopKConfig( - top_k=top_k, - use_grouped_topk=use_grouped_topk, -@@ -260,6 +263,7 @@ class TopK(CustomOp): - self.topk_config.torch_native = True - return 
select_experts( - hidden_states=hidden_states, -+ layer_id=self.layer_id, - router_logits=router_logits, - topk_config=self.topk_config, - num_token_non_padded=num_token_non_padded, -@@ -309,6 +313,7 @@ class TopK(CustomOp): - ): - topk_output = select_experts( - hidden_states=hidden_states, -+ layer_id=self.layer_id, - router_logits=router_logits, - topk_config=self.topk_config, - num_token_non_padded=num_token_non_padded, -@@ -326,6 +331,7 @@ class TopK(CustomOp): - ) -> TopKOutput: - return select_experts( - hidden_states=hidden_states, -+ layer_id=self.layer_id, - router_logits=router_logits, - topk_config=self.topk_config, - num_token_non_padded=num_token_non_padded, -@@ -856,6 +862,7 @@ def select_experts( - router_logits: torch.Tensor, - topk_config: TopKConfig, - *, -+ layer_id: Optional[int] = None, - num_token_non_padded: Optional[torch.Tensor] = None, - expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, - ) -> StandardTopKOutput: -@@ -983,7 +990,10 @@ def select_experts( - ) ++ layer.is_marlin_converted = False - get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) -- -+ get_global_experts_capturer().capture( -+ layer_id=layer_id, -+ topk_ids=topk_ids, -+ ) - return StandardTopKOutput(topk_weights, topk_ids, router_logits) - - -diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py -index 70466bb20..cd85fc2f2 100644 ---- a/python/sglang/srt/layers/moe/utils.py -+++ b/python/sglang/srt/layers/moe/utils.py -@@ -284,7 +284,7 @@ def speculative_moe_a2a_backend_context(): - global MOE_A2A_BACKEND - original_backend = MOE_A2A_BACKEND - try: -- MOE_A2A_BACKEND = MoeA2ABackend.NONE -+ MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() - yield - finally: - MOE_A2A_BACKEND = original_backend + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py -index 0cdb7e1ae..df8860409 100644 +index 56516b41b..cb2ebca60 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py -@@ -15,7 +15,6 @@ from sglang.srt.server_args import get_global_server_args - from sglang.srt.utils import ( - cpu_has_amx_support, - get_bool_env_var, -- get_compiler_backend, - is_cpu, - is_cuda, - is_hip, -@@ -132,9 +131,7 @@ class RotaryEmbedding(CustomOp): +@@ -135,9 +135,7 @@ class RotaryEmbedding(MultiPlatformOp): if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native @@ -742,82 +487,21 @@ index 0cdb7e1ae..df8860409 100644 self.position_cos, self.position_sin = None, None def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: -@@ -1423,6 +1420,9 @@ class MRotaryEmbedding(RotaryEmbedding): - f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" - ) - -+ if get_global_server_args().rl_on_policy_target is not None: -+ self._forward_method = self.forward_native -+ - def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: - # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) - # is expensive, so avoid calling it if possible -@@ -1432,8 +1432,7 @@ class MRotaryEmbedding(RotaryEmbedding): - ): - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - -- @torch.compile(dynamic=True, backend=get_compiler_backend()) -- def _forward_native( -+ def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, 
-@@ -1490,7 +1489,7 @@ class MRotaryEmbedding(RotaryEmbedding): - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - -- def forward( -+ def forward_cuda( - self, - positions: torch.Tensor, - query: torch.Tensor, -@@ -1507,14 +1506,12 @@ class MRotaryEmbedding(RotaryEmbedding): - """ - assert positions.ndim == 1 or positions.ndim == 2 - -- if positions.ndim == 2 and self.mrope_section and _is_cuda: -- return self._forward_triton(positions, query, key) -- elif _is_npu: -- return self._forward_npu(positions, query, key) -- else: -- return self._forward_native(positions, query, key) -+ # Use Triton kernel for multimodal (2D positions) with mrope -+ if positions.ndim == 2 and self.mrope_section: -+ return self.forward_triton(positions, query, key) -+ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) - -- def _forward_triton( -+ def forward_triton( - self, - positions: torch.Tensor, - query: torch.Tensor, -@@ -1563,15 +1560,19 @@ class MRotaryEmbedding(RotaryEmbedding): - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - -- def _forward_npu( -+ def forward_npu( - self, - positions: torch.Tensor, - query: torch.Tensor, +@@ -1577,6 +1575,9 @@ class MRotaryEmbedding(RotaryEmbedding): key: torch.Tensor, -+ fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for npu implementation" # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 - if query.shape[1] > 4096: -- return self._forward_native(positions, query, key) -+ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) - rotary_mode = "half" - if self.is_neox_style: - rotary_mode = "half" + assert ( + fused_set_kv_buffer_arg is None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py -index 7f6f6a010..c4a673145 100644 +index 55bef5652..35ad68b1c 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py -@@ -105,16 +105,11 @@ class Sampler(nn.Module): +@@ -108,16 +108,11 @@ class Sampler(nn.Module): if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: probs_without_temp_scaling = torch.softmax(logits, dim=-1) @@ -837,210 +521,35 @@ index 7f6f6a010..c4a673145 100644 # For ascend backend, softmax is not needed before sampling if not get_global_server_args().sampling_backend == "ascend" or ( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB -diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py -index 87922077e..8cb6bad8d 100644 ---- a/python/sglang/srt/managers/detokenizer_manager.py -+++ b/python/sglang/srt/managers/detokenizer_manager.py -@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): - s.sent_offset = len(output_str) - output_strs.append(incremental_output) - -+ output_routed_experts = [] -+ if recv_obj.output_routed_experts is not None: -+ output_routed_experts = [ -+ ( -+ output_routed_experts.tolist() -+ if output_routed_experts is not None -+ else [] -+ ) -+ for output_routed_experts in recv_obj.output_routed_experts -+ ] - return BatchStrOutput( - rids=recv_obj.rids, - http_worker_ipcs=recv_obj.http_worker_ipcs, -@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): - 
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, - output_token_entropy_val=recv_obj.output_token_entropy_val, - output_hidden_states=recv_obj.output_hidden_states, -+ output_routed_experts=output_routed_experts, - placeholder_tokens_idx=None, - placeholder_tokens_val=None, - retraction_counts=recv_obj.retraction_counts, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index e34736cc4..5e5997a1a 100644 +index 879e1bfa6..de52085fa 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py -@@ -23,6 +23,8 @@ from dataclasses import dataclass, field - from enum import Enum - from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +@@ -1286,6 +1286,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): + success: bool + message: str -+import torch ++@dataclass ++class PostProcessWeightsReqInput(BaseReq): ++ # Whether to restore weights before loading new weights ++ restore_weights_before_load: bool = False ++ # Whether to enable quantization post-processing ++ post_process_quantization: bool = False + - from sglang.srt.lora.lora_registry import LoRARef - from sglang.srt.managers.schedule_batch import BaseFinishReason - from sglang.srt.multimodal.mm_utils import has_valid_data -@@ -175,6 +177,8 @@ class GenerateReqInput(BaseReq): - log_metrics: bool = True - # Whether to return hidden states - return_hidden_states: Union[List[bool], bool] = False -+ # Whether to return captured routed experts -+ return_routed_experts: bool = False - - # The modalities of the image data [image, multi-images, video] - modalities: Optional[List[str]] = None -@@ -592,6 +596,7 @@ class GenerateReqInput(BaseReq): - if isinstance(self.return_hidden_states, list) - else self.return_hidden_states - ), -+ return_routed_experts=self.return_routed_experts, - modalities=self.modalities[i] if self.modalities else None, - session_params=self.session_params, - lora_path=self.lora_path[i] if self.lora_path is not None else None, -@@ -655,6 +660,9 @@ class TokenizedGenerateReqInput(BaseReq): - # Whether to return hidden states - return_hidden_states: bool = False - -+ # Whether to return captured routed experts -+ return_routed_experts: bool = False + - # The input embeds - input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None - -@@ -910,6 +918,9 @@ class BatchTokenIDOutput( - # Hidden states - output_hidden_states: List[List[float]] - -+ # The routed experts for each output token -+ output_routed_experts: List[torch.Tensor] ++@dataclass ++class PostProcessWeightsReqOutput(BaseReq): ++ success: bool ++ message: str + - # The information of placeholder tokens (e.g., image token) - # idx is the index of the token in the prompt after expansion. - # val is the length of padded tokens after expansion. -@@ -989,6 +1000,9 @@ class BatchStrOutput( - # Hidden states - output_hidden_states: List[List[float]] - -+ # The routed experts for each output token -+ output_routed_experts: List[List[int]] -+ - # The information of placeholder tokens (e.g., image token) - # idx is the index of the token in the prompt after expansion. - # val is the length of padded tokens after expansion. 
+ + @dataclass + class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index c4c5a9ebb..1450c5fd8 100644 +index 468d8fb8a..229a9a2dc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py -@@ -450,6 +450,7 @@ class Req: - session_id: Optional[str] = None, - custom_logit_processor: Optional[str] = None, - return_hidden_states: bool = False, -+ return_routed_experts: bool = False, - eos_token_ids: Optional[Set[int]] = None, - bootstrap_host: Optional[str] = None, - bootstrap_port: Optional[int] = None, -@@ -629,6 +630,12 @@ class Req: - self.output_topk_p = None - self.output_topk_index = None - -+ # capture routed experts -+ self.return_routed_experts = return_routed_experts -+ self.routed_experts: Optional[torch.Tensor] = ( -+ None # cpu tensor: shape (seqlen, topk) -+ ) -+ - # Embedding (return values) - self.embedding = None - -@@ -992,6 +999,7 @@ class Req: - self.retraction_count += 1 - - self.prefix_indices = torch.empty((0,), dtype=torch.int64) -+ self.routed_experts = [] - self.last_node = None - self.swa_uuid_for_lock = None - self.extend_input_len = 0 -@@ -1159,6 +1167,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - # Whether to return hidden states - return_hidden_states: bool = False - -+ # Whether to return captured experts -+ return_routed_experts: bool = False -+ - # Whether this batch is prefill-only (no token generation needed) - is_prefill_only: bool = False - -@@ -1206,6 +1217,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - device=req_to_token_pool.device, - spec_algorithm=spec_algorithm, - return_hidden_states=any(req.return_hidden_states for req in reqs), -+ return_routed_experts=any(req.return_routed_experts for req in reqs), - is_prefill_only=all(req.is_prefill_only for req in reqs), - chunked_req=chunked_req, - dllm_config=dllm_config, -@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - self.req_pool_indices = req_pool_indices_tensor - self.orig_seq_lens = orig_seq_lens_tensor - self.out_cache_loc = out_cache_loc -+ self.out_cache_loc_cpu = out_cache_loc.cpu() - self.input_embeds = ( - torch.tensor(input_embeds).to(self.device, non_blocking=True) - if input_embeds -@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - - input_ids = torch.cat([self.input_ids, running_batch.input_ids]) - out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) -+ out_cache_loc_cpu = torch.cat( -+ [self.out_cache_loc_cpu, running_batch.out_cache_loc_cpu] -+ ) - - self.merge_batch(running_batch) - self.input_ids = input_ids - self.out_cache_loc = out_cache_loc -+ self.out_cache_loc_cpu = out_cache_loc_cpu - - # For overlap scheduler, the output_ids has one step delay - delta = 0 if self.enable_overlap else -1 -@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) - self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) - self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) -+ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu") - self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) - self.seq_lens_sum = 0 - self.extend_num_tokens = 0 -@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - - # 
Allocate memory - self.out_cache_loc = alloc_for_decode(self, token_per_req=1) -+ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) - - # Update req-level memory management fields - for req in self.reqs: -@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] - self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] - self.out_cache_loc = None -+ self.out_cache_loc_cpu = None - self.seq_lens_sum = self.seq_lens.sum().item() - self.output_ids = self.output_ids[keep_indices_device] - self.return_logprob = any(req.return_logprob for req in self.reqs) -@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) - self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) - self.out_cache_loc = None -+ self.out_cache_loc_cpu = None - self.seq_lens_sum += other.seq_lens_sum - if self.output_ids is not None: - self.output_ids = torch.cat([self.output_ids, other.output_ids]) -@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): - seq_lens=self.seq_lens, - orig_seq_lens=self.orig_seq_lens, - out_cache_loc=self.out_cache_loc, -+ out_cache_loc_cpu=self.out_cache_loc_cpu, - seq_lens_cpu=seq_lens_cpu, - seq_lens_sum=self.seq_lens_sum, - return_logprob=self.return_logprob, -@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -2181,7 +2181,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def __str__(self): return ( f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " @@ -1050,97 +559,40 @@ index c4c5a9ebb..1450c5fd8 100644 ) -@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: - # Sampling info - sampling_info: SamplingBatchInfo - -+ # cpu copy of out_cache_loc -+ out_cache_loc_cpu: Optional[torch.Tensor] = None -+ - # The original sequence lengths, Qwen-1M related - orig_seq_lens: Optional[torch.Tensor] = None - diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index b801fd8f8..9e27cc825 100644 +index bca1c31e6..0c82e37a4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py -@@ -1305,6 +1305,7 @@ class Scheduler( - input_embeds=recv_req.input_embeds, - custom_logit_processor=recv_req.custom_logit_processor, - return_hidden_states=recv_req.return_hidden_states, -+ return_routed_experts=recv_req.return_routed_experts, - eos_token_ids=self.model_config.hf_eos_token_id, - bootstrap_host=recv_req.bootstrap_host, - bootstrap_port=recv_req.bootstrap_port, +@@ -97,6 +97,7 @@ from sglang.srt.managers.io_struct import ( + OpenSessionReqInput, + OpenSessionReqOutput, + PauseGenerationReqInput, ++ PostProcessWeightsReqInput, + ProfileReq, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, +@@ -1055,6 +1056,7 @@ class Scheduler( + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), ++ (PostProcessWeightsReqInput, self.post_process_weights), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), + (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index c48f5f893..a9796c25f 
100644 +index e40586c24..32d98aee4 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -@@ -9,6 +9,7 @@ import torch - from sglang.srt.disaggregation.utils import DisaggregationMode +@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.layers.logits_processor import LogitsProcessorOutput -+from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer ++ from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOutput, -@@ -112,6 +113,14 @@ class SchedulerOutputProcessorMixin: - req.check_finished() - - if req.finished(): -+ req.routed_experts = ( -+ get_global_experts_capturer().get_routed_experts( -+ req_pool_idx=req.req_pool_idx, -+ seqlen=req.seqlen, -+ req_to_token_pool=self.req_to_token_pool, -+ ) -+ ) -+ - release_kv_cache(req, self.tree_cache) - req.time_stats.completion_time = time.perf_counter() - elif not batch.decoding_reqs or req not in batch.decoding_reqs: -@@ -362,6 +371,12 @@ class SchedulerOutputProcessorMixin: - req.check_finished(new_accepted_len) - - if req.finished(): -+ req.routed_experts = get_global_experts_capturer().get_routed_experts( -+ req_pool_idx=req.req_pool_idx, -+ seqlen=req.seqlen, -+ req_to_token_pool=self.req_to_token_pool, -+ ) -+ - if self.server_args.disaggregation_decode_enable_offload_kvcache: - # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes - if not self.decode_offload_manager.offload_kv_cache(req): -@@ -756,6 +771,7 @@ class SchedulerOutputProcessorMixin: - spec_accepted_tokens = [] - retraction_counts = [] - output_hidden_states = None -+ output_routed_experts = None - - queue_times = [] - forward_entry_times = [] -@@ -946,6 +962,10 @@ class SchedulerOutputProcessorMixin: - if output_hidden_states is None: - output_hidden_states = [] - output_hidden_states.append(req.hidden_states) -+ if req.return_routed_experts: -+ if output_routed_experts is None: -+ output_routed_experts = [] -+ output_routed_experts.append(req.routed_experts) - - if ( - req.finished() -@@ -994,6 +1014,7 @@ class SchedulerOutputProcessorMixin: - output_token_ids_logprobs_idx=output_token_ids_logprobs_idx, - output_token_entropy_val=None, - output_hidden_states=output_hidden_states, -+ output_routed_experts=output_routed_experts, - placeholder_tokens_idx=None, - placeholder_tokens_val=None, - retraction_counts=retraction_counts, diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index f8ebfc1f4..a05449fac 100644 +index 293a84350..68911c433 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ @@ -1161,7 +613,28 @@ index f8ebfc1f4..a05449fac 100644 from sglang.srt.managers.io_struct import ( CheckWeightsReqInput, CheckWeightsReqOutput, -@@ -127,6 +131,13 @@ class SchedulerUpdateWeightsMixin: +@@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import ( + GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ PostProcessWeightsReqInput, ++ PostProcessWeightsReqOutput, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, +@@ -113,6 +119,11 @@ class 
SchedulerUpdateWeightsMixin: + logger.error(message) + torch.distributed.barrier(group=self.tp_cpu_group) + return UpdateWeightsFromIPCReqOutput(success, message) ++ ++ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): ++ """Optional post-processing for updated weights (e.g., Marlin conversion).""" ++ success, message = self.tp_worker.post_process_weights(recv_req) ++ return PostProcessWeightsReqOutput(success, message) + + def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): + parameter = self.tp_worker.get_weights_by_name(recv_req) +@@ -137,6 +148,13 @@ class SchedulerUpdateWeightsMixin: self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) self.flush_cache() @@ -1175,49 +648,7 @@ index f8ebfc1f4..a05449fac 100644 if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = _export_static_state( self.tp_worker.model_runner.model -@@ -137,6 +148,20 @@ class SchedulerUpdateWeightsMixin: - if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: - self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH) - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ tp_group = get_tp_group() -+ if tp_group is not None and tp_group.pynccl_comm is not None: -+ tp_group.pynccl_comm.nccl_pause() -+ attn_tp_group = get_attention_tp_group() -+ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: -+ attn_tp_group.pynccl_comm.nccl_pause() -+ moe_ep_group = get_moe_ep_group() -+ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: -+ moe_ep_group.pynccl_comm.nccl_pause() -+ moe_tp_group = get_moe_tp_group() -+ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: -+ moe_tp_group.pynccl_comm.nccl_pause() -+ - torch.get_device_module().synchronize() - - return ReleaseMemoryOccupationReqOutput() -@@ -155,6 +180,20 @@ class SchedulerUpdateWeightsMixin: - if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH) - -+ if os.environ.get("AMEM_ENABLE", "0") == "1": -+ tp_group = get_tp_group() -+ if tp_group is not None and tp_group.pynccl_comm is not None: -+ tp_group.pynccl_comm.nccl_resume() -+ attn_tp_group = get_attention_tp_group() -+ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: -+ attn_tp_group.pynccl_comm.nccl_resume() -+ moe_ep_group = get_moe_ep_group() -+ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: -+ moe_ep_group.pynccl_comm.nccl_resume() -+ moe_tp_group = get_moe_tp_group() -+ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: -+ moe_tp_group.pynccl_comm.nccl_resume() -+ - if GPU_MEMORY_TYPE_WEIGHTS in tags: - self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) - torch.distributed.barrier(self.tp_cpu_group) -@@ -167,6 +206,13 @@ class SchedulerUpdateWeightsMixin: +@@ -177,6 +195,13 @@ class SchedulerUpdateWeightsMixin: if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) @@ -1231,29 +662,63 @@ index f8ebfc1f4..a05449fac 100644 return ResumeMemoryOccupationReqOutput() def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +index e5d42bed8..412293b30 100644 +--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py ++++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py +@@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import ( + LoadLoRAAdapterReqOutput, + LoRAUpdateOutput, + 
OpenSessionReqInput, ++ PostProcessWeightsReqInput, ++ PostProcessWeightsReqOutput, + ProfileReq, + ProfileReqOutput, + ProfileReqType, +@@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin: + self.update_weights_from_ipc_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) ++ self.post_process_weights_communicator = _Communicator( ++ self.send_to_scheduler, server_args.dp_size ++ ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) +@@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin: + UpdateWeightsFromIPCReqOutput, + self.update_weights_from_ipc_communicator.handle_recv, + ), ++ ( ++ PostProcessWeightsReqOutput, ++ self.post_process_weights_communicator.handle_recv, ++ ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, +@@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin: + + return success, message + ++ async def post_process_weights( ++ self: TokenizerManager, ++ obj: PostProcessWeightsReqInput, ++ request: Optional[fastapi.Request] = None, ++ ) -> Tuple[bool, str]: ++ """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" ++ self.auto_create_handle_loop() ++ async with self.model_update_lock.writer_lock: ++ results = await self.post_process_weights_communicator(obj) ++ return _Communicator.merge_results(results) ++ + async def init_weights_send_group_for_remote_instance( + self, + obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index b90cf0616..98d71d896 100644 +index f4fc29e29..5ef12cca6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): - session_params=session_params, - custom_logit_processor=obj.custom_logit_processor, - return_hidden_states=obj.return_hidden_states, -+ return_routed_experts=obj.return_routed_experts, - data_parallel_rank=obj.data_parallel_rank, - priority=obj.priority, - extra_key=obj.extra_key, -@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): - if getattr(recv_obj, "output_hidden_states", None): - meta_info["hidden_states"] = recv_obj.output_hidden_states[i] - -+ if getattr(recv_obj, "output_routed_experts", None): -+ meta_info["routed_experts"] = recv_obj.output_routed_experts[i] -+ - if isinstance(recv_obj, BatchStrOutput): - state.text += recv_obj.output_strs[i] - if self.server_args.stream_output and state.obj.stream: -@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): +@@ -1652,12 +1652,13 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi return if len(recv_obj.input_token_logprobs_val) > 0: @@ -1273,276 +738,156 @@ index b90cf0616..98d71d896 100644 state.output_token_logprobs_val.extend( recv_obj.output_token_logprobs_val[recv_obj_index] ) -diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py -index 3a85e6a7e..2859dafa1 100644 ---- a/python/sglang/srt/model_executor/forward_batch_info.py -+++ b/python/sglang/srt/model_executor/forward_batch_info.py -@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( - set_dp_buffer_len, - set_is_extend_in_batch, - ) -+from sglang.srt.server_args import get_global_server_args - from sglang.srt.utils import get_compiler_backend, is_npu, support_triton - 
from sglang.srt.utils.common import ceil_align - -@@ -214,6 +215,9 @@ class ForwardBatch: - # The sum of all sequence lengths - seq_lens_sum: int - -+ # cpu copy of out_cache_loc -+ out_cache_loc_cpu: Optional[torch.Tensor] = None -+ - # The original sequence length without being chunked. Qwen-1M related. - orig_seq_lens: Optional[torch.Tensor] = None - -@@ -368,6 +372,7 @@ class ForwardBatch: - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - out_cache_loc=batch.out_cache_loc, -+ out_cache_loc_cpu=batch.out_cache_loc_cpu, - mm_inputs=batch.multimodal_inputs, - encoder_cached=batch.encoder_cached, - encoder_lens=batch.encoder_lens, -@@ -623,7 +628,10 @@ class ForwardBatch: - mm_input = batch.multimodal_inputs[batch_idx] - if self.forward_mode.is_decode(): - # 3 * N -- if mm_input is None: -+ if ( -+ mm_input is None -+ or get_global_server_args().rl_on_policy_target is not None -+ ): - mrope_positions_list[batch_idx] = torch.full( - (3, 1), - self.seq_lens[batch_idx] - 1, -@@ -640,7 +648,10 @@ class ForwardBatch: - batch.extend_seq_lens[batch_idx], - batch.extend_prefix_lens[batch_idx], - ) -- if mm_input is None: -+ if ( -+ mm_input is None -+ or get_global_server_args().rl_on_policy_target is not None -+ ): - # text only - mrope_positions = torch.tensor( - [ -@@ -823,6 +834,10 @@ class ForwardBatch: - ) - - self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) -+ if self.out_cache_loc_cpu is not None: -+ self.out_cache_loc_cpu = self._pad_tensor_to_size( -+ self.out_cache_loc_cpu, num_tokens -+ ) - if self.encoder_lens is not None: - self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) - self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py +index 1f1875254..51d8651ce 100644 +--- a/python/sglang/srt/managers/tp_worker.py ++++ b/python/sglang/srt/managers/tp_worker.py +@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( + InitWeightsSendGroupForRemoteInstanceReqInput, + InitWeightsUpdateGroupReqInput, + LoadLoRAAdapterReqInput, ++ PostProcessWeightsReqInput, + SendWeightsToRemoteInstanceReqInput, + UnloadLoRAAdapterReqInput, + UpdateWeightFromDiskReqInput, +@@ -175,6 +176,11 @@ class BaseTpWorker(ABC): + success, message = self.model_runner.update_weights_from_ipc(recv_req) + return success, message + ++ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): ++ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" ++ success, message = self.model_runner.post_process_weights(recv_req) ++ return success, message ++ + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): + parameter = self.model_runner.get_weights_by_name( + recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 4d58278b7..8f50dc430 100644 +index 1d69c0582..d984c2e12 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( - set_is_extend_in_batch, - ) - from sglang.srt.layers.logits_processor import LogitsProcessorOutput -+from sglang.srt.layers.moe.routed_experts_capturer import ( -+ RoutedExpertsCapturer, -+ get_global_experts_capturer, -+ set_global_experts_capturer, -+) - from sglang.srt.layers.pooler import EmbeddingPoolerOutput - from 
sglang.srt.layers.sampler import Sampler - from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model -@@ -502,6 +507,10 @@ class ModelRunner: - server_args.max_running_requests, - server_args.max_total_tokens, +@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): ) -+ -+ # Init routed experts capturer -+ self.init_routed_experts_capturer() -+ + + # Init routed experts capturer +- self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + if self.device == "cuda": self.init_cublas() - self.init_attention_backend() -@@ -545,6 +554,40 @@ class ModelRunner: - # Initialize piecewise CUDA graph - self.init_piecewise_cuda_graphs() - -+ def init_routed_experts_capturer(self): -+ # TODO: the redundant logic with TpModelWorker -+ max_running_requests = min( -+ ( -+ self.max_total_num_tokens // 2 -+ if self.server_args.max_running_requests is None -+ else self.server_args.max_running_requests -+ // ( -+ self.server_args.dp_size -+ if self.server_args.enable_dp_attention -+ else 1 +@@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): + output.expert_distribution_metrics = recorder_outputs.get("metrics") + + # Copy cached routing experts' buffers back to CPU cache +- get_global_experts_capturer().on_forward_end( +- forward_batch=forward_batch, +- can_run_graph=output.can_run_graph, +- cuda_graph_batch=getattr(self.graph_runner, "bs", None), +- ) ++ if not self.is_draft_worker: ++ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass ++ # the actual number of tokens per dp rank in cuda graph, not batch size. ++ cuda_graph_num_tokens = None ++ if getattr(self.graph_runner, "bs", None): ++ cuda_graph_num_tokens = ( ++ self.graph_runner.bs * self.graph_runner.num_tokens_per_bs + ) -+ ), -+ self.req_to_token_pool.size, -+ ) -+ -+ if not self.server_args.disable_shared_experts_fusion and hasattr( -+ self.model, "num_fused_shared_experts" -+ ): -+ num_fused_shared_experts = self.model.num_fused_shared_experts -+ else: -+ num_fused_shared_experts = 0 -+ -+ set_global_experts_capturer( -+ RoutedExpertsCapturer.create( -+ enable=get_global_server_args().enable_return_routed_experts, -+ model_config=self.model_config, -+ num_fused_shared_experts=num_fused_shared_experts, -+ num_tokens=self.max_total_num_tokens + self.page_size, -+ max_running_requests=max_running_requests, -+ device=self.device, -+ ) -+ ) -+ - def model_specific_adjustment(self): - server_args = self.server_args - -@@ -792,7 +835,11 @@ class ModelRunner: - ) - with self.memory_saver_adapter.region( - GPU_MEMORY_TYPE_WEIGHTS, -- enable_cpu_backup=enable_cpu_backup, -+ enable_cpu_backup=( -+ self.server_args.enable_weights_cpu_backup -+ if not self.is_draft_worker -+ else True -+ ), - ): - self.model = get_model( - model_config=self.model_config, -@@ -2645,9 +2692,12 @@ class ModelRunner: - ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: - self.forward_pass_id += 1 - -- with get_global_expert_distribution_recorder().with_forward_pass( -- self.forward_pass_id, -- forward_batch, -+ with ( -+ get_global_expert_distribution_recorder().with_forward_pass( -+ self.forward_pass_id, -+ forward_batch, -+ ), -+ get_global_experts_capturer().with_forward(forward_batch), - ): - output = self._forward_raw( - forward_batch, -@@ -2656,6 +2706,13 @@ class ModelRunner: - reinit_attn_backend, - split_forward_count, - ) -+ # Copy cached routing experts' buffers back to CPU cache -+ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( -+ 
device_loc=forward_batch.out_cache_loc, -+ cpu_loc=forward_batch.out_cache_loc_cpu, -+ can_run_graph=output[1], -+ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ get_global_experts_capturer().on_forward_end( ++ forward_batch=forward_batch, ++ can_run_graph=output.can_run_graph, ++ cuda_graph_batch=cuda_graph_num_tokens, + ) if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() +@@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin): + logger.error(f"IPC weight update failed: {e}") + return False, str(e) + ++ def post_process_weights(self, recv_req): ++ """ ++ Execute post-processing logic for model weights, such as Marlin quantization format conversion. ++ """ ++ from sglang.srt.model_loader.loader import device_loading_context ++ ++ target_device = torch.device("cuda", torch.cuda.current_device()) ++ ++ if recv_req.restore_weights_before_load: ++ for _, module in self.model.named_modules(): ++ quant_method = getattr(module, "quant_method", None) ++ ++ # Check if the module supports restoring weights ++ if quant_method is not None and hasattr( ++ quant_method, "restore_weights_before_loading" ++ ): ++ ++ with device_loading_context(module, target_device): ++ quant_method.restore_weights_before_loading(module) ++ ++ if recv_req.post_process_quantization: ++ # Iterate through all modules to apply specific post-loading processing ++ for _, module in self.model.named_modules(): ++ quant_method = getattr(module, "quant_method", None) ++ ++ # Check if the module supports quantization post-processing ++ if quant_method is not None and hasattr( ++ quant_method, "process_weights_after_loading" ++ ): ++ ++ # Apply the post-processing (e.g., repacking weights for Marlin kernel) ++ with device_loading_context(module, target_device): ++ quant_method.process_weights_after_loading(module) ++ ++ return True, "Success" + + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): + params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py -index dc30b4f0a..f29dc4b71 100644 +index 2918461d3..d44c8aaa0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py -@@ -667,6 +667,7 @@ class DeepseekV2MoE(nn.Module): - - self.topk = TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, -+ layer_id=self.layer_id, - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, -diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py -index ab1b6576b..dffd8f09a 100644 ---- a/python/sglang/srt/models/ernie4.py -+++ b/python/sglang/srt/models/ernie4.py -@@ -87,6 +87,7 @@ class Ernie4Moe(nn.Module): - - self.topk = TopK( - top_k=config.moe_k, -+ layer_id=layer_id, - renormalize=True, - use_grouped_topk=False, - correction_bias=self.gate.e_score_correction_bias, -diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index a9689b8f2..bc8538da8 100644 ---- a/python/sglang/srt/models/glm4_moe.py -+++ b/python/sglang/srt/models/glm4_moe.py -@@ -379,6 +379,17 @@ class Glm4MoeSparseMoeBlock(nn.Module): - - self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix)) - -+ self.topk = TopK( -+ top_k=self.top_k, -+ layer_id=self.layer_id, -+ renormalize=config.norm_topk_prob, -+ use_grouped_topk=True, -+ num_expert_group=config.n_group, -+ topk_group=config.topk_group, -+ 
correction_bias=self.gate.e_score_correction_bias, -+ routed_scaling_factor=self.routed_scaling_factor, -+ ) -+ - self.experts = get_moe_impl_class(quant_config)( - num_experts=config.n_routed_experts + self.num_fused_shared_experts, - num_fused_shared_experts=self.num_fused_shared_experts, -diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py -index 9474700c4..398d622ff 100644 ---- a/python/sglang/srt/models/gpt_oss.py -+++ b/python/sglang/srt/models/gpt_oss.py -@@ -113,6 +113,7 @@ class GptOssSparseMoeBlock(nn.Module): - self.topk = TopK( - top_k=config.num_experts_per_tok, - renormalize=True, -+ layer_id=layer_id, - ) - - self.top_k = config.num_experts_per_tok -diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py -index fd513060a..a089475b7 100644 ---- a/python/sglang/srt/models/grok.py -+++ b/python/sglang/srt/models/grok.py -@@ -142,6 +142,7 @@ class Grok1MoE(nn.Module): - self.topk = TopK( - top_k=top_k, - renormalize=False, -+ layer_id=layer_id, - custom_routing_function=custom_routing_function, - ) - -diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py -index 7c6fd9e48..b20d28544 100644 ---- a/python/sglang/srt/models/hunyuan.py -+++ b/python/sglang/srt/models/hunyuan.py -@@ -150,6 +150,7 @@ class HunYuanSparseMoeBlock(nn.Module): - - self.topk = TopK( - top_k=top_k, -+ layer_id=layer_id, - renormalize=True if top_k > 1 else False, - ) - -diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py -index 3530609ba..01c89e893 100644 ---- a/python/sglang/srt/models/longcat_flash.py -+++ b/python/sglang/srt/models/longcat_flash.py -@@ -245,6 +245,7 @@ class LongcatFlashMoE(nn.Module): - renormalize=False, - use_grouped_topk=False, - correction_bias=self.router.e_score_correction_bias.data, -+ layer_id=layer_id, - ) - self.topk.forward = self.topk.forward_native +@@ -2704,7 +2704,11 @@ class DeepseekV2AttentionMLA(nn.Module): + ): + k = k_nope.new_empty(*k_shape) + concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) +- elif _is_cuda: ++ elif _is_cuda and all( ++ # (i.bit_count() == 1) == (is_power_of_two(i)) ++ i.bit_count() == 1 ++ for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1]) ++ ): + # fa3 mha support fp8 inputs + if ( + self.current_attention_backend == "fa3" +@@ -3997,16 +4001,17 @@ class DeepseekV2ForCausalLM(nn.Module): + f"model.layers.{nextn_layer_id}.mlp.{expert_sub_name}.{stem}" + ) + +- for partial_name in tqdm.tqdm( +- partial_names, +- desc="quant weights to fp8 ue8m0", +- ): +- original_weight = weights_dict[f"{partial_name}.weight"] +- out_w, out_s = quant_weight_ue8m0( +- original_weight, weight_block_size=weight_block_size +- ) +- weights_dict[f"{partial_name}.weight"] = out_w +- weights_dict[f"{partial_name}.weight_scale_inv"] = out_s ++ if len(partial_names) > 0: ++ for partial_name in tqdm.tqdm( ++ partial_names, ++ desc="quant weights to fp8 ue8m0", ++ ): ++ original_weight = weights_dict[f"{partial_name}.weight"] ++ out_w, out_s = quant_weight_ue8m0( ++ original_weight, weight_block_size=weight_block_size ++ ) ++ weights_dict[f"{partial_name}.weight"] = out_w ++ weights_dict[f"{partial_name}.weight_scale_inv"] = out_s + if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): + self._mark_nextn_moe_weights_as_ue8m0() diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index a7dbadec6..c83a41338 100644 --- a/python/sglang/srt/models/qwen2.py @@ -1582,18 +927,10 @@ index 
a7dbadec6..c83a41338 100644 if get_global_server_args().rl_on_policy_target is not None else {} diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py -index ea33e81ef..561934dce 100644 +index 3ad9f6736..0b9c7f499 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py -@@ -161,6 +161,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): - self.topk = TopK( - top_k=config.num_experts_per_tok, - renormalize=config.norm_topk_prob, -+ layer_id=layer_id, - ) - - self.experts = get_moe_impl_class(quant_config)( -@@ -581,7 +582,17 @@ class Qwen2MoeModel(nn.Module): +@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): prefix=add_prefix("layers", prefix), ) if self.pp_group.is_last_rank: @@ -1613,7 +950,7 @@ index ea33e81ef..561934dce 100644 self.norm = PPMissingLayer(return_tuple=True) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 30b92acbd..a0d14895f 100644 +index 9220831f6..47a1a4e4c 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): @@ -1626,7 +963,7 @@ index 30b92acbd..a0d14895f 100644 ) if get_global_server_args().rl_on_policy_target is not None else {} -@@ -256,10 +256,8 @@ class Qwen3DecoderLayer(nn.Module): +@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): norm_kwargs = ( dict( @@ -1638,24 +975,8 @@ index 30b92acbd..a0d14895f 100644 ) if get_global_server_args().rl_on_policy_target is not None else {} -@@ -289,10 +287,14 @@ class Qwen3DecoderLayer(nn.Module): - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - residual: Optional[torch.Tensor], -+ post_residual_addition: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - hidden_states, residual = self.layer_communicator.prepare_attn( -- hidden_states, residual, forward_batch -+ hidden_states, -+ residual, -+ forward_batch, -+ post_residual_addition=post_residual_addition, - ) - if hidden_states.shape[0] != 0: - hidden_states = self.self_attn( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index 9737ac719..09c756918 100644 +index e11678a9e..e277d46f2 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -22,6 +22,7 @@ import math @@ -1675,17 +996,15 @@ index 9737ac719..09c756918 100644 from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -@@ -227,7 +228,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): - top_k=config.num_experts_per_tok, - renormalize=config.norm_topk_prob, +@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): use_grouped_topk=False, -+ layer_id=layer_id, + layer_id=layer_id, ) + self.top_k = config.num_experts_per_tok self.experts = get_moe_impl_class(quant_config)( num_experts=config.num_experts -@@ -293,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): +@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) @@ -1709,7 +1028,7 @@ index 9737ac719..09c756918 100644 final_hidden_states = self.experts(hidden_states, topk_output) if ( self.tp_size > 1 -@@ -474,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): +@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): ) self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, 
MRotaryEmbedding) else True @@ -1725,7 +1044,7 @@ index 9737ac719..09c756918 100644 ) self._used_fused_qk_norm_rope_last_call = False -@@ -493,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): +@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): prefix=add_prefix("attn", prefix), ) @@ -1743,8 +1062,8 @@ index 9737ac719..09c756918 100644 + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) self.alt_stream = alt_stream - def _apply_qk_norm( -@@ -751,9 +778,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + def op_prepare(self, state): +@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module): quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) @@ -1767,90 +1086,44 @@ index 9737ac719..09c756918 100644 self.layer_communicator = LayerCommunicator( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index ed52f7ff4..8ce9fab9d 100644 +index 891913078..c9dbecd23 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py -@@ -18,7 +18,6 @@ import re - from functools import lru_cache, partial - from typing import Callable, Iterable, List, Optional, Tuple, Union - --import numpy as np - import torch - import torch.nn as nn - from einops import rearrange -@@ -349,83 +348,65 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): - return rotary_pos_emb +@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return cos_combined, sin_combined def fast_pos_embed_interpolate(self, grid_thw): +- patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] - num_grid_per_side = int(self.num_position_embeddings**0.5) ++ num_grid_per_side = int(self.num_position_embeddings**0.5) + device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - -- # TODO: use torch instand of np -- for t, h, w in grid_thw: -- h_idxs = np.linspace(0, num_grid_per_side - 1, h) -- w_idxs = np.linspace(0, num_grid_per_side - 1, w) ++ ++ idx_list = [[] for _ in range(4)] ++ weight_list = [[] for _ in range(4)] ++ + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w) - -- h_idxs_floor = h_idxs.astype(int) -- w_idxs_floor = w_idxs.astype(int) -- h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) -- w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) ++ + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - -- idx_list[0].extend( -- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None]) -- .flatten() -- .tolist() -- * t -- ) -- idx_list[1].extend( -- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None]) -- .flatten() -- .tolist() -- * t -- ) -- idx_list[2].extend( -- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None]) -- .flatten() -- .tolist() -- * t -- ) -- idx_list[3].extend( -- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None]) -- .flatten() -- .tolist() -- * t -- ) ++ ++ dh = h_idxs - h_idxs_floor ++ dw = w_idxs - w_idxs_floor ++ + base_h = h_idxs_floor * num_grid_per_side + base_h_ceil = h_idxs_ceil * num_grid_per_side - -- weight_list[0].extend( -- ((1 - dh)[None].T * (1 - 
dw)[None]).flatten().tolist() * t -- ) -- weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t) -- weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t) -- weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t) ++ + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] - -- device = self.pos_embed.weight.device -- dtype = self.pos_embed.weight.dtype ++ + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), @@ -1858,17 +1131,11 @@ index ed52f7ff4..8ce9fab9d 100644 + (dh[None].T * dw[None]).flatten(), + ] -- p0 = ( -- self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device)) -- * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None] -- ) -- p1 = ( -- self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device)) -- * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None] -- ) -- p2 = ( -- self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device)) -- * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None] +- embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device) +- embeds = ( +- self.pos_embed(embeds) +- .permute(1, 0) +- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) @@ -1877,33 +1144,40 @@ index ed52f7ff4..8ce9fab9d 100644 + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device ) -- p3 = ( -- self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device)) -- * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None] +- for t, h, w in grid_thw: +- pos_embed = torch.nn.functional.interpolate( +- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners +- ) +- pos_embed = pos_embed.reshape( +- -1, +- h // self.spatial_merge_size, +- self.spatial_merge_size, +- w // self.spatial_merge_size, +- self.spatial_merge_size, + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] - ) - -- patch_pos_embeds = p0 + p1 + p2 + p3 -- patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw]) - patch_pos_embeds_permute = [] -- m_size = self.spatial_merge_size -- for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): ++ ) ++ ++ patch_pos_embeds_permute = [] + merge_size = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( -- pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1) ++ pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) ++ .permute(0, 1, 3, 2, 4, 5) ++ .flatten(0, 4) ) -@@ -555,21 +536,27 @@ class Qwen3LLMModel(Qwen3Model): +- pos_embed = pos_embed.permute(1, 3, 2, 4, 0) +- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1) + patch_pos_embeds_permute.append(pos_embed) + return torch.cat(patch_pos_embeds_permute) + +@@ -607,14 +647,19 @@ class Qwen3LLMModel(Qwen3Model): hidden_states + residual if 
residual is not None else hidden_states ) @@ -1916,103 +1190,73 @@ index ed52f7ff4..8ce9fab9d 100644 + :, sep : sep + self.hidden_size + ] + -+ # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. -+ # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 -+ # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack -+ # The order matters because addition with different tensors is not associative in practice. + # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. + # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 + # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack + # The order matters because addition with different tensors is not associative in practice. +- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. +- deepstack_embeds = self.get_deepstack_embeds( +- layer_idx - 1, input_deepstack_embeds +- ) hidden_states, residual = layer( positions, hidden_states, - forward_batch, - residual, -+ post_residual_addition=deepstack_embeds, - ) - -- # process deepstack -- if ( -- input_deepstack_embeds is not None -- and layer_idx in self.deepstack_embed_to_decoder_layer -- ): -- sep = self.hidden_size * layer_idx -- hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size] -- - if not self.pp_group.is_last_rank: - return PPProxyTensors( - { -diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py -index 4474f62d5..0e537c398 100644 ---- a/python/sglang/srt/models/step3_vl.py -+++ b/python/sglang/srt/models/step3_vl.py -@@ -129,6 +129,7 @@ class Step3TextMoEMLP(nn.Module): - top_k=config.moe_top_k, - renormalize=config.norm_expert_weight, - use_grouped_topk=False, -+ layer_id=layer_id, - ) - - self.experts = get_moe_impl_class(quant_config)( -diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py -index 370aec2b6..47666d8f3 100644 ---- a/python/sglang/srt/multimodal/processors/base_processor.py -+++ b/python/sglang/srt/multimodal/processors/base_processor.py -@@ -13,6 +13,7 @@ from PIL import Image - from transformers import BaseImageProcessorFast - - from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem -+from sglang.srt.server_args import get_global_server_args - from sglang.srt.utils import ( - get_bool_env_var, - is_npu, -@@ -260,7 +261,9 @@ class BaseMultimodalProcessor(ABC): - and isinstance(processor.image_processor, BaseImageProcessorFast) - and not self.server_args.disable_fast_image_processor - ): -- if not _is_npu: -+ if get_global_server_args().rl_on_policy_target is not None: -+ kwargs["device"] = "cpu" -+ elif not _is_npu: - kwargs["device"] = "cuda" - elif processor.__class__.__name__ not in { - "Qwen2_5_VLProcessor", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index 8e7753dab..323788f39 100644 +index 54d4e415a..de7620c20 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py -@@ -535,6 +535,7 @@ class ServerArgs: - disable_fast_image_processor: bool = False - keep_mm_feature_on_device: bool = False - enable_return_hidden_states: bool = False -+ enable_return_routed_experts: bool = False - scheduler_recv_interval: int 
= 1 - numa_node: Optional[List[int]] = None - enable_deterministic_inference: bool = False -@@ -1966,6 +1967,9 @@ class ServerArgs: - "Enable deterministic inference because of rl_on_policy_target." - ) - self.enable_deterministic_inference = True -+ -+ # For VLM -+ os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "0" - # TODO remove this environment variable as a whole - os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1" - -@@ -3705,6 +3709,11 @@ class ServerArgs: +@@ -523,6 +523,7 @@ class ServerArgs: + cuda_graph_max_bs: Optional[int] = None + cuda_graph_bs: Optional[List[int]] = None + disable_cuda_graph: bool = False ++ disable_draft_cuda_graph: bool = False + disable_cuda_graph_padding: bool = False + enable_profile_cuda_graph: bool = False + enable_cudagraph_gc: bool = False +@@ -3951,6 +3952,11 @@ class ServerArgs: action="store_true", - help="Enable returning hidden states with responses.", + help="Disable cuda graph.", ) + parser.add_argument( -+ "--enable-return-routed-experts", ++ "--disable-draft-cuda-graph", + action="store_true", -+ help="Enable returning routed experts of each layer with responses.", ++ help="Disable cuda graph for draft model in speculative decoding.", + ) parser.add_argument( - "--scheduler-recv-interval", - type=int, + "--disable-cuda-graph-padding", + action="store_true", +diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +index 5fe45086c..c95fbd0f6 100644 +--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py ++++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: + self.seq_lens.fill_(self.seq_len_fill_value) + self.out_cache_loc.zero_() + self.positions.zero_() +- ++ self.topk_p.zero_() ++ self.topk_index.zero_() ++ self.hidden_states.zero_() ++ self.req_pool_indices.zero_() + num_tokens = bs * self.num_tokens_per_bs + + # Common inputs +@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner: + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) +- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) +- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) ++ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) ++ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1)) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index b3d72df05..ddfe0b178 100644 +index 1bf3816e9..b5b41dba4 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py -@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] @@ -2023,7 +1267,7 @@ index b3d72df05..ddfe0b178 100644 else: # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` self.topk_p = self.topk_p[new_indices] -@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): +@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, 
EagleDraftInputV2Mixin): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) @@ -2051,3 +1295,16 @@ index b3d72df05..ddfe0b178 100644 @dataclass +diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py +index a702df4f8..61d9ae366 100644 +--- a/python/sglang/srt/speculative/eagle_worker.py ++++ b/python/sglang/srt/speculative/eagle_worker.py +@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker): + self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None + +- if self.server_args.disable_cuda_graph: ++ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph: + return + + Device2DraftCudaGraphRunner = { diff --git a/docker/patch/v0.5.7/megatron.patch b/docker/patch/v0.5.7/megatron.patch new file mode 100644 index 000000000..a337b19fb --- /dev/null +++ b/docker/patch/v0.5.7/megatron.patch @@ -0,0 +1,681 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git 
a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..20ee977b0 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + 
tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): 
+ sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). 
+ +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+ mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..58dc4bb70 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if 
use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..517944f25 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? 
++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = 
build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. 
This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch new file mode 100644 index 000000000..42d23ed65 --- /dev/null +++ b/docker/patch/v0.5.7/sglang.patch @@ -0,0 +1,864 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index 199885244..742ad0639 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -314,6 +314,13 @@ class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index 32e8c0b69..df913da7b 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index ac11013f8..478e469f6 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -309,6 +309,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index 0478526ef..cfb1aa669 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for 
the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index b07164c53..8e6722ce0 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -194,10 +191,22 @@ class RMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype + post_residual_addition = kwargs.get("post_residual_addition") ++ ++ if residual is not None and not self.fp32_residual: ++ x = ( ++ x ++ + residual ++ + ( ++ post_residual_addition ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: ++ if residual is not None and self.fp32_residual: + x = ( + x + + residual.to(torch.float32) +@@ -207,10 +216,7 @@ class RMSNorm(MultiPlatformOp): + else 0.0 + ) + ) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index fa7431048..cd33ea735 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index a1885fade..14d692365 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -573,7 +574,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. 
+- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +index 00bd68755..5a3ca8a67 100644 +--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py ++++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -1,5 +1,6 @@ + import logging + from abc import ABC ++from contextlib import contextmanager + from typing import Optional + + import numpy as np +@@ -8,13 +9,18 @@ import torch + + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.layers.dp_attention import ( ++ attn_tp_all_gather_into_tensor, + get_attention_dp_rank, ++ get_attention_tp_size, + get_dp_local_info, + is_dp_attention_enabled, + ) + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.server_args import get_global_server_args ++from sglang.srt.layers.moe import ( ++ get_moe_a2a_backend, ++) + + logger = logging.getLogger(__name__) + +@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + device=device, + ) + ++ if get_moe_a2a_backend().is_deepep(): ++ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 ++ self.gather_buffer = torch.empty( ++ ( ++ self.device_cache.buffer.shape[0] * attn_tp_size, ++ self.device_cache.buffer.shape[2], ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ + def _sync_fwd_experts_buffer_DtoH( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: int, + ): +- if is_dp_attention_enabled(): ++ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer ++ # contains data from all DP ranks. We should not slice by DP rank in this case. 
++ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + # handle with cuda graph padding + if can_run_graph: +@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + ].cpu() + + def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ if get_moe_a2a_backend().is_deepep(): ++ local_topk_ids = topk_ids ++ topk_ids = self.gather_buffer[ ++ : local_topk_ids.size(0) * get_attention_tp_size() ++ ] ++ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + def get_routed_experts( +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 56516b41b..cb2ebca60 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -135,9 +135,7 @@ class RotaryEmbedding(MultiPlatformOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1577,6 +1575,9 @@ class MRotaryEmbedding(RotaryEmbedding): + key: torch.Tensor, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + assert ( + fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 55bef5652..35ad68b1c 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -108,16 +108,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index 468d8fb8a..229a9a2dc 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -2181,7 +2181,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index e40586c24..32d98aee4 100644 +--- 
a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer ++ + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index 293a84350..0947f77e0 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -137,6 +141,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -177,6 +188,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index f4fc29e29..5ef12cca6 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -1652,12 +1652,13 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + return + + if len(recv_obj.input_token_logprobs_val) > 0: +- state.input_token_logprobs_val.extend( +- recv_obj.input_token_logprobs_val[recv_obj_index] +- ) +- state.input_token_logprobs_idx.extend( +- recv_obj.input_token_logprobs_idx[recv_obj_index] +- ) ++ if recv_obj.input_token_logprobs_val[recv_obj_index]: ++ state.input_token_logprobs_val.extend( ++ recv_obj.input_token_logprobs_val[recv_obj_index] ++ ) ++ state.input_token_logprobs_idx.extend( ++ recv_obj.input_token_logprobs_idx[recv_obj_index] ++ ) + 
state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 1d69c0582..9027374be 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): + ) + + # Init routed experts capturer +- self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + + if self.device == "cuda": + self.init_cublas() +@@ -2224,11 +2225,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): + output.expert_distribution_metrics = recorder_outputs.get("metrics") + + # Copy cached routing experts' buffers back to CPU cache +- get_global_experts_capturer().on_forward_end( +- forward_batch=forward_batch, +- can_run_graph=output.can_run_graph, +- cuda_graph_batch=getattr(self.graph_runner, "bs", None), +- ) ++ if not self.is_draft_worker: ++ get_global_experts_capturer().on_forward_end( ++ forward_batch=forward_batch, ++ can_run_graph=output.can_run_graph, ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index 2918461d3..2bcc67087 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -2704,7 +2704,11 @@ class DeepseekV2AttentionMLA(nn.Module): + ): + k = k_nope.new_empty(*k_shape) + concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) +- elif _is_cuda: ++ elif _is_cuda and all( ++ # (i.bit_count() == 1) == (is_power_of_two(i)) ++ i.bit_count() == 1 ++ for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1]) ++ ): + # fa3 mha support fp8 inputs + if ( + self.current_attention_backend == "fa3" +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index 3ad9f6736..0b9c7f499 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ 
cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 9220831f6..47a1a4e4c 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index e11678a9e..e277d46f2 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import RoutingMethodType + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.layers.radix_attention import RadixAttention +@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + use_grouped_topk=False, + layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights = routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 +@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + 
get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def op_prepare(self, state): +@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index 891913078..c9dbecd23 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return cos_combined, sin_combined + + def fast_pos_embed_interpolate(self, grid_thw): +- patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] ++ num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device ++ ++ idx_list = [[] for _ in range(4)] ++ weight_list = [[] for _ in range(4)] ++ ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) ++ ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ ++ dh = h_idxs - h_idxs_floor ++ dw = w_idxs - w_idxs_floor ++ ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side ++ ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] ++ ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device) +- embeds = ( +- self.pos_embed(embeds) +- .permute(1, 0) +- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, 
device=device) ++ weight_tensor = torch.tensor( ++ weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) +- for t, h, w in grid_thw: +- pos_embed = torch.nn.functional.interpolate( +- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners +- ) +- pos_embed = pos_embed.reshape( +- -1, +- h // self.spatial_merge_size, +- self.spatial_merge_size, +- w // self.spatial_merge_size, +- self.spatial_merge_size, ++ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] ++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ++ ++ patch_pos_embeds = patch_pos_embeds.split( ++ [h * w for h, w in zip(grid_hs, grid_ws)] ++ ) ++ ++ patch_pos_embeds_permute = [] ++ merge_size = self.spatial_merge_size ++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): ++ pos_embed = pos_embed.repeat(t, 1) ++ pos_embed = ( ++ pos_embed.view( ++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1 ++ ) ++ .permute(0, 1, 3, 2, 4, 5) ++ .flatten(0, 4) + ) +- pos_embed = pos_embed.permute(1, 3, 2, 4, 0) +- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1) + patch_pos_embeds_permute.append(pos_embed) + return torch.cat(patch_pos_embeds_permute) + +@@ -607,14 +647,19 @@ class Qwen3LLMModel(Qwen3Model): + hidden_states + residual if residual is not None else hidden_states + ) + ++ deepstack_embeds = None ++ if input_deepstack_embeds is not None: ++ prev_layer_idx = layer_idx - 1 ++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer: ++ sep = self.hidden_size * prev_layer_idx ++ deepstack_embeds = input_deepstack_embeds[ ++ :, sep : sep + self.hidden_size ++ ] ++ + # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. + # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 + # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack + # The order matters because addition with different tensors is not associative in practice. +- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. 
+- deepstack_embeds = self.get_deepstack_embeds( +- layer_idx - 1, input_deepstack_embeds +- ) + hidden_states, residual = layer( + positions, + hidden_states, +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 54d4e415a..de7620c20 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -523,6 +523,7 @@ class ServerArgs: + cuda_graph_max_bs: Optional[int] = None + cuda_graph_bs: Optional[List[int]] = None + disable_cuda_graph: bool = False ++ disable_draft_cuda_graph: bool = False + disable_cuda_graph_padding: bool = False + enable_profile_cuda_graph: bool = False + enable_cudagraph_gc: bool = False +@@ -3951,6 +3952,11 @@ class ServerArgs: + action="store_true", + help="Disable cuda graph.", + ) ++ parser.add_argument( ++ "--disable-draft-cuda-graph", ++ action="store_true", ++ help="Disable cuda graph for draft model in speculative decoding.", ++ ) + parser.add_argument( + "--disable-cuda-graph-padding", + action="store_true", +diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +index 5fe45086c..c95fbd0f6 100644 +--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py ++++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: + self.seq_lens.fill_(self.seq_len_fill_value) + self.out_cache_loc.zero_() + self.positions.zero_() +- ++ self.topk_p.zero_() ++ self.topk_index.zero_() ++ self.hidden_states.zero_() ++ self.req_pool_indices.zero_() + num_tokens = bs * self.num_tokens_per_bs + + # Common inputs +@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner: + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) +- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) +- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) ++ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) ++ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1)) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + +diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py +index 1bf3816e9..b5b41dba4 100644 +--- a/python/sglang/srt/speculative/eagle_info.py ++++ b/python/sglang/srt/speculative/eagle_info.py +@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] ++ if self.accept_length is not None: ++ self.accept_length = self.accept_length[: len(new_indices)] ++ if self.accept_length_cpu is not None: ++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] +@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) ++ if self.accept_length is not None and spec_info.accept_length is not None: ++ self.accept_length = torch.cat( 
++ [self.accept_length, spec_info.accept_length] ++ ) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif self.accept_length is not None: ++ zeros = torch.zeros( ++ [spec_info.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([self.accept_length, zeros]) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif spec_info.accept_length is not None: ++ zeros = torch.zeros( ++ [self.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([zeros, spec_info.accept_length]) ++ self.accept_length_cpu = self.accept_length.tolist() + + + @dataclass +diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py +index a702df4f8..61d9ae366 100644 +--- a/python/sglang/srt/speculative/eagle_worker.py ++++ b/python/sglang/srt/speculative/eagle_worker.py +@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker): + self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None + +- if self.server_args.disable_cuda_graph: ++ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph: + return + + Device2DraftCudaGraphRunner = { diff --git a/docker/version.txt b/docker/version.txt index b480e0254..f072c7789 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20251212a \ No newline at end of file +nightly-dev-20260113a diff --git a/docs/README.md b/docs/README.md index 9e4971a1f..5251aa5ab 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # Miles Documentation -We recommend new contributors start from writing documentation, which helps you quickly understand SGLang codebase. +We recommend new contributors start from writing documentation, which helps you quickly understand miles codebase. Most documentation files are located under the `docs/` folder. ## Docs Workflow diff --git a/docs/en/advanced/pd-disaggregation.md b/docs/en/advanced/pd-disaggregation.md new file mode 100644 index 000000000..f509d72d8 --- /dev/null +++ b/docs/en/advanced/pd-disaggregation.md @@ -0,0 +1,5 @@ +# PD Disaggregation + +miles supports Prefill and Decode disaggregation (PD Disaggregation). + +You can set the number of servers used for Prefill by setting the `--prefill-num-servers` argument. diff --git a/docs/en/advanced/speculative-decoding.md b/docs/en/advanced/speculative-decoding.md index f85d0ca1f..2c37b6aa4 100644 --- a/docs/en/advanced/speculative-decoding.md +++ b/docs/en/advanced/speculative-decoding.md @@ -4,7 +4,7 @@ Speculative decoding is a key optimization for speeding up rollouts. Instead of ## Accelerating Inference with Speculative Decoding -For models with MTP layers (e.g., GLM-4.6, DeepSeek-V3/R1), simply add: +For models with MTP layers (e.g., GLM-4.7, DeepSeek-V3/R1), simply add: ```bash --sglang-speculative-algorithm EAGLE diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index e1c24e3ad..19b418e97 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -11,7 +11,7 @@ Regarding parallelism, for sglang we will enable EP64, activate dp attention, an ## Environment Setup -For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](./qwen3-4B.md). +For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](qwen3-4B.md). 
To prepare the DeepSeek R1 checkpoint, first you will need to download DeepSeek-R1 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): @@ -85,7 +85,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/deepseek-v3.sh" ``` -This reads the model's config from [scripts/models/deepseek-v3.sh](../../../scripts/models/deepseek-v3.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's config from [scripts/models/deepseek-v3.sh](https://github.com/radixark/miles/blob/main/scripts/models/deepseek-v3.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/radixark/miles/tree/main/scripts/models/). #### CKPT\_ARGS diff --git a/docs/en/examples/glm4-9B.md b/docs/en/examples/glm4-9B.md index f46e9f373..36629568f 100644 --- a/docs/en/examples/glm4-9B.md +++ b/docs/en/examples/glm4-9B.md @@ -8,7 +8,7 @@ After pulling the `radixark/miles:latest` image, initialize the image environmen cd /root/ git clone https://github.com/radixark/miles.git cd miles/ -pip install -e . +pip install -e . --no-deps ``` Download the model and data: @@ -49,7 +49,7 @@ bash scripts/run-glm4-9B.sh ### Parameter Introduction -Here, we will briefly introduce the various components of the [run-glm4-9B.sh](../../../scripts/run-glm4-9B.sh) script: +Here, we will briefly introduce the various components of the [run-glm4-9B.sh](https://github.com/radixark/miles/blob/main/scripts/run-glm4-9B.sh) script: #### MODEL\_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4-9B.sh" ``` -Reads the model's config from [scripts/models/glm4-9B.sh](../../../scripts/models/glm4-9B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +Reads the model's config from [scripts/models/glm4-9B.sh](https://github.com/radixark/miles/blob/main/scripts/models/glm4-9B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/radixark/miles/tree/main/scripts/models/). โš ๏ธ Ensure that settings such as `--rotary-base` in the model configuration file match the settings of the model you are currently training. This is because different models, even with the same architecture, might use different values. If needed, you can override these parameters in your script after loading the model weights. For instance: diff --git a/docs/en/examples/glm4.5-355B-A32B.md b/docs/en/examples/glm4.5-355B-A32B.md index 50b7c4921..b5336ee17 100644 --- a/docs/en/examples/glm4.5-355B-A32B.md +++ b/docs/en/examples/glm4.5-355B-A32B.md @@ -5,12 +5,12 @@ This is an example of doing GLM-4.5 RL training using 64xH100 GPUs. ## Environment Setup -For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](./qwen3-4B.md). 
+For instructions on setting up the environment and downloading data, please refer to [Example: Qwen3-4B](qwen3-4B.md). First, you will need to download GLM-4.5 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): ```bash -huggingface-cli download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B +hf download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B ``` Next, we need to convert the huggingface checkpoint into the torch_dist format with 2 nodes, each with 8 GPUs: @@ -66,7 +66,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/glm4.5-355B-A32B.sh" ``` -This reads the model's config from [scripts/models/glm4.5-355B-A32B.sh](../../../scripts/models/glm4.5-355B-A32B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's config from [scripts/models/glm4.5-355B-A32B.sh](https://github.com/radixark/miles/blob/main/scripts/models/glm4.5-355B-A32B.sh). These configs are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/radixark/miles/tree/main/scripts/models/). #### PERF\_ARGS diff --git a/docs/en/examples/qwen3-30B-A3B.md b/docs/en/examples/qwen3-30B-A3B.md index 965ef7eb5..2f731d461 100644 --- a/docs/en/examples/qwen3-30B-A3B.md +++ b/docs/en/examples/qwen3-30B-A3B.md @@ -3,13 +3,13 @@ ## Environment Preparation -The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](./qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B. +The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B. To convert huggingface checkpoint to torch_dist, please try: ```bash cd miles/ -pip install -e . +pip install -e . --no-deps source scripts/models/qwen3-30B-A3B.sh PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \ tools/convert_hf_to_torch_dist.py \ @@ -29,7 +29,7 @@ bash scripts/run-qwen3-30B-A3B.sh ### Parameter Introduction -Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) script. +Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](https://github.com/radixark/miles/blob/main/scripts/run-qwen3-30B-A3B.sh) script. 1. To support running Qwen3-30B-A3B in an 8xH800 environment, we need to enable Megatron's CPU Adam to save GPU memory. The corresponding configuration is: diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 1966fd823..8374de4be 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -8,7 +8,7 @@ After pulling the `radixark/miles:latest` image, initialize the image environmen cd /root/ git clone https://github.com/radixark/miles.git cd miles/ -pip install -e . +pip install -e . 
--no-deps ``` Download the model and data: @@ -49,7 +49,7 @@ bash scripts/run-qwen3-4B.sh ### Parameter Introduction -Here, we will briefly introduce the various components of the [run-qwen3-4B.sh](../../../scripts/run-qwen3-4B.sh) script: +Here, we will briefly introduce the various components of the [run-qwen3-4B.sh](https://github.com/radixark/miles/blob/main/scripts/run-qwen3-4B.sh) script: #### MODEL\_ARGS @@ -58,7 +58,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-4B.sh" ``` -This reads the model's configuration from [scripts/models/qwen3-4B.sh](../../../scripts/models/qwen3-4B.sh). These are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](../../../scripts/models/). +This reads the model's configuration from [scripts/models/qwen3-4B.sh](https://github.com/radixark/miles/blob/main/scripts/models/qwen3-4B.sh). These are all Megatron parameters. When training with Megatron, it cannot read the model config from the checkpoint, so we need to configure it ourselves. We provide some examples in [scripts/models](https://github.com/radixark/miles/tree/main/scripts/models/). โš ๏ธ Ensure that settings such as `--rotary-base` in the model configuration file match the settings of the model you are currently training. This is because different models, even with the same architecture, might use different values. If needed, you can override these parameters in your script after loading the model weights. For instance: diff --git a/docs/en/examples/qwen3-4b-base-openhermes.md b/docs/en/examples/qwen3-4b-base-openhermes.md index cab853008..a4b6237b6 100644 --- a/docs/en/examples/qwen3-4b-base-openhermes.md +++ b/docs/en/examples/qwen3-4b-base-openhermes.md @@ -3,7 +3,7 @@ ## Environment Preparation -First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](./models/qwen3-4B.md). +First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](qwen3-4B.md). After that, we will process the SFT data. Here, we use the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First, we process the data into a format suitable for `miles` to load. You can use the following script to add a column that conforms to the OpenAI message format and save it to `/root/openhermes2_5.parquet`. @@ -50,7 +50,7 @@ bash script/run-qwen3-4B-base-sft.sh ### Parameter Introduction -You can compare [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B.sh) with [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows: +You can compare [run-qwen3-4B-base-sft.sh](https://github.com/radixark/miles/blob/main/scripts/run-qwen3-4B-base-sft.sh) with [run-qwen3-4B.sh](https://github.com/radixark/miles/blob/main/scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows: 1. Removed `SGLANG_ARGS` and `GRPO_ARGS`. This is because it is not necessary to start SGLang or configure GRPO-related settings during the SFT process. 
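For reference, a minimal sketch of the OpenHermes-2.5 data-preparation step mentioned above (writing `/root/openhermes2_5.parquet`), assuming the dataset rows carry a `conversations` list with `from`/`value` keys and that miles reads an OpenAI-style `messages` column — both are assumptions here; the script shipped with the docs is authoritative:

```python
# Hypothetical sketch: turn OpenHermes-2.5 into a parquet file with an
# OpenAI-message-format column. The field names ("conversations", "from",
# "value", "messages") are assumptions, not taken from the miles codebase.
from datasets import load_dataset

ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def to_openai_messages(example):
    # Map each turn of the source conversation into {"role", "content"} pairs.
    return {
        "messages": [
            {"role": ROLE_MAP.get(turn["from"], "user"), "content": turn["value"]}
            for turn in example["conversations"]
        ]
    }


dataset = load_dataset("teknium/OpenHermes-2.5", split="train")
dataset = dataset.map(to_openai_messages)
dataset.to_parquet("/root/openhermes2_5.parquet")
```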
diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md new file mode 100644 index 000000000..b1088ce64 --- /dev/null +++ b/docs/en/get_started/customization.md @@ -0,0 +1,419 @@ +# Customization Guide + +miles provides extensive customization capabilities through function path arguments. These allow you to inject custom logic at various stages of the training and rollout pipeline without modifying the core codebase. + +## Overview of Customization Interfaces + +Below is a summary of all available customization interfaces and their purposes. + +| Interface Argument | Purpose | +| :--- | :--- | +| [`--rollout-function-path`](#1-rollout-function---rollout-function-path) | Override the entire rollout generation logic. | +| [`--custom-generate-function-path`](#2-custom-generate-function---custom-generate-function-path) | Override only the generation step (e.g., for RAG or tool use). | +| [`--custom-rm-path`](#3-reward-model---custom-rm-path) | Implement custom reward computation logic. | +| [`--dynamic-sampling-filter-path`](#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | Filter samples during dynamic sampling (e.g., DAPO). | +| [`--buffer-filter-path`](#5-buffer-filter---buffer-filter-path) | Filter samples in the rollout buffer before training. | +| [`--rollout-sample-filter-path`](#6-rollout-sample-filter---rollout-sample-filter-path) | Determine if individual samples participate in loss calculation. | +| [`--rollout-all-samples-process-path`](#7-rollout-all-samples-process---rollout-all-samples-process-path) | Process all samples (including filtered ones) after rollout. | +| [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. | +| [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. | +| [`--custom-tis-function-path`](#10-custom-tisrs-function---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. | +| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., for Dr.GRPO). | +| [`--custom-reward-post-process-path`](#12-reward-post-processing---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. | +| [`--custom-convert-samples-to-train-data-path`](#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. | +| [`--custom-rollout-log-function-path`](#14-logging-functions) | Custom logging for training rollouts. | +| [`--custom-eval-rollout-log-function-path`](#14-logging-functions) | Custom logging for evaluation rollouts. | +| [`--data-source-path`](#15-data-source---data-source-path) | Override the data source for rollout prompts. | +| [`--eval-function-path`](#16-evaluation-function---eval-function-path) | Override the rollout function specifically for evaluation. | +| [`--custom-megatron-init-path`](#17-megatron-hooks) | Custom initialization after Megatron setup. | +| [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hooks) | Custom logic before log probability computation. | +| [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hooks) | Custom logic before each training step. 
| +| [`--miles-router-middleware-paths`](#18-miles-router-middleware---miles-router-middleware-paths) | Add custom middleware to miles router. | + +## Detailed Interface Reference + +### 1. Rollout Function (`--rollout-function-path`) + +**Default**: `miles.rollout.sglang_rollout.generate_rollout` + +**Purpose**: Override the entire rollout generation logic. + +**Signature**: +```python +async def generate_rollout(args, rollout_id, *, evaluation=False) -> RolloutFnTrainOutput | RolloutFnEvalOutput +``` + +**Use Cases**: +- Implementing complex multi-turn conversations +- Adding custom sampling strategies +- Integrating external tools or APIs during generation + +**Example**: See [examples/multi_agent/rollout_with_multi_agents.py](../../../examples/multi_agent/rollout_with_multi_agents.py) + +--- + +### 2. Custom Generate Function (`--custom-generate-function-path`) + +**Default**: `None` (uses built-in generate function) + +**Purpose**: Override only the generation step within the default rollout function. + +**Signature**: +```python +async def custom_generate(args, sample: Sample, sampling_params: dict) -> Sample +``` + +**Use Cases**: +- Implementing tool-calling or function-calling capabilities +- Adding retrieval-augmented generation (RAG) +- Multi-turn conversation handling + +**Example**: See [examples/search-r1/generate_with_search.py](../../../examples/search-r1/generate_with_search.py) + +--- + +### 3. Reward Model (`--custom-rm-path`) + +**Default**: `None` (uses built-in reward models based on `--rm-type`) + +**Purpose**: Implement custom reward computation logic. + +**Signature** (single sample mode): +```python +async def custom_rm(args, sample: Sample) -> float +``` + +**Signature** (batch mode, when `--group-rm` is enabled): +```python +async def batched_custom_rm(args, samples: list[Sample]) -> list[float] +``` + +**Use Cases**: +- Custom rule-based rewards +- Integration with external reward model services +- Multi-dimensional reward signals + +**Built-in Options** (`--rm-type`): +- `math`: Mathematical answer verification +- `dapo`: DAPO-style scoring +- `deepscaler`: DeepScaler rule-based reward +- `f1`: F1 score computation +- `gpqa`: GPQA reward computation +- `ifbench`: IFBench reward computation +- `remote_rm`: Remote reward model service (requires `--rm-url`) + +--- + +### 4. Dynamic Sampling Filter (`--dynamic-sampling-filter-path`) + +**Default**: `None` + +**Purpose**: Filter samples during dynamic sampling (e.g., DAPO-style filtering). + +**Signature**: +```python +def filter_function(args, samples: list[Sample], **kwargs) -> DynamicFilterOutput +``` + +**Return Type**: +```python +@dataclass +class DynamicFilterOutput: + keep: bool # Whether to keep this sample group + reason: str | None # Reason for filtering (for logging) +``` + +**Use Cases**: +- Filtering out samples where all responses have the same reward +- Implementing curriculum learning strategies +- Quality-based sample selection + +**Example**: `miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std` + +--- + +### 5. Buffer Filter (`--buffer-filter-path`) + +**Default**: `None` + +**Purpose**: Filter samples in the rollout buffer before training. + +**Signature**: +```python +def buffer_filter(samples: list[list[Sample]]) -> list[list[Sample]] +``` + +**Use Cases**: +- Removing low-quality samples before training +- Implementing priority-based sample selection +- Balancing sample distributions + +--- + +### 6. 
Rollout Sample Filter (`--rollout-sample-filter-path`) + +**Default**: `None` + +**Purpose**: Determine whether individual samples participate in loss calculation. + +**Signature**: +```python +def filter_function(args, samples: list[Sample]) -> None +``` + +**Note**: This function should directly modify the `remove_sample` attribute of each `Sample` object. + +**Use Cases**: +- Filtering samples based on response quality +- Implementing selective training strategies + +--- + +### 7. Rollout All Samples Process (`--rollout-all-samples-process-path`) + +**Default**: `None` + +**Purpose**: Process all samples (including filtered ones) after rollout. + +**Signature**: +```python +def process_function(args, samples: list[list[Sample]]) -> None +``` + +**Use Cases**: +- Logging and analysis of all generated samples +- Computing statistics across filtered and kept samples + +--- + +### 8. Rollout Data Postprocess (`--rollout-data-postprocess-path`) + +**Default**: `None` + +**Purpose**: Post-process rollout data after log probabilities are computed. + +**Signature**: +```python +def postprocess_function(args, samples: list[list[Sample]]) -> None +``` + +**Use Cases**: +- Updating loss masks based on computed values +- Adding additional metadata to samples + +--- + +### 9. Custom Loss Function (`--custom-loss-function-path`) + +**Default**: `None` (requires `--loss-type custom_loss`) + +**Purpose**: Implement custom training loss computation. + +**Use Cases**: +- Novel RL objectives +- Multi-objective optimization +- Custom regularization terms + +--- + +### 10. Custom TIS/RS Function (`--custom-tis-function-path`) + +**Default**: `None` + +**Purpose**: Implement custom importance sampling for off-policy correction. + +**Use Cases**: +- Custom importance sampling ratio computation +- Advanced off-policy correction methods + +**Example**: `examples/train_infer_mismatch_helper/mis.py:compute_mis_weights_with_cp` + +--- + +### 11. Custom pg_loss Reducer (`--custom-pg-loss-reducer-function-path`) + +**Default**: `None` + +**Purpose**: Customize the reduction of pg_loss while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. + +**Signature**: +```python +def get_pg_loss_reducer( + total_lengths: list[int], + response_lengths: list[int], + loss_masks: list[torch.Tensor], + calculate_per_token_loss: bool = False, +) -> Callable[[torch.Tensor], torch.Tensor] +``` + +**Use Cases**: +- Dr.GRPO: Divide by a constant instead of effective token count +- Custom loss normalization strategies + +**Example**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer` + +--- + +### 12. Reward Post-Processing (`--custom-reward-post-process-path`) + +**Default**: `None` (uses default GRPO normalization) + +**Purpose**: Custom post-processing of rewards before advantage computation. + +**Use Cases**: +- Custom reward normalization strategies +- Reward shaping + +--- + +### 13. Samples to Train Data Conversion (`--custom-convert-samples-to-train-data-path`) + +**Default**: `None` (uses built-in conversion logic) + +**Purpose**: Override the conversion of samples to training data format. 
+ +**Signature**: +```python +def convert_samples_to_train_data( + args, + samples: list[Sample] | list[list[Sample]], +) -> dict +``` + +**Return Type**: +```python +dict: { + "tokens": list[list[int]], # Token IDs for each sample + "response_lengths": list[int], # Response lengths + "rewards": list[float], # Normalized rewards + "raw_reward": list[float], # Raw rewards + "truncated": list[int], # Truncation flags (0 or 1) + "sample_indices": list[int], # Sample indices + "loss_masks": list[list[int]], # Loss masks for each sample + # Optional fields: + "round_number": list[int], # Round numbers (for rollout buffer) + "rollout_log_probs": list, # Log probs (for off-policy correction) + "rollout_routed_experts": list, # Routed experts (for MoE) + "metadata": list, # Train metadata + "multimodal_train_inputs": list, # Multimodal tensors (for VLM) + "teacher_log_probs": list, # Teacher log probs (for distillation) +} +``` + +**Use Cases**: +- Handling `list[list[Sample]]` inputs +- Custom data format requirements for training + +--- + +### 14. Logging Functions + +#### Training Rollout Logging (`--custom-rollout-log-function-path`) + +**Signature**: +```python +def log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time) -> bool +``` + +**Return**: `True` to skip default logging, `False` to continue with default logging. + +#### Evaluation Rollout Logging (`--custom-eval-rollout-log-function-path`) + +**Signature**: +```python +def log_eval_rollout_data(rollout_id, args, data, extra_metrics) -> bool +``` + +**Return**: `True` to skip default logging, `False` to continue with default logging. + +--- + +### 15. Data Source (`--data-source-path`) + +**Default**: `miles.rollout.data_source.RolloutDataSourceWithBuffer` + +**Purpose**: Override the data source for rollout prompts. + +**Base Class**: `miles.rollout.data_source.DataSource` + +**Required Methods**: +```python +class CustomDataSource(DataSource): + def get_samples(self, num_samples: int) -> list[list[Sample]]: + """Return num_samples samples""" + + def add_samples(self, samples: list[list[Sample]]): + """Add samples back to the data source""" + + def save(self, rollout_id): + """Save state for checkpointing""" + + def load(self, rollout_id=None): + """Load state from checkpoint""" +``` + +--- + +### 16. Evaluation Function (`--eval-function-path`) + +**Default**: Same as `--rollout-function-path` + +**Purpose**: Override the rollout function specifically for evaluation. + +**Use Cases**: +- Different sampling parameters for evaluation +- Evaluation-specific logic + +--- + +### 17. Megatron Hooks + +#### Megatron Initialization (`--custom-megatron-init-path`) + +**Signature**: +```python +def custom_init(args) -> None +``` + +**Purpose**: Custom initialization after Megatron setup. + +#### Before Log Prob Hook (`--custom-megatron-before-log-prob-hook-path`) + +**Signature**: +```python +def custom_hook(args, model, store_prefix) -> None +``` + +**Purpose**: Custom logic before log probability computation. + +#### Before Train Step Hook (`--custom-megatron-before-train-step-hook-path`) + +**Signature**: +```python +def custom_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler) -> None +``` + +**Purpose**: Custom logic before each training step. + +--- + +### 18. miles Router Middleware (`--miles-router-middleware-paths`) + +**Purpose**: Add custom middleware to the miles router for request processing. 
+ +**Use Cases**: +- Request/response transformation +- Custom routing logic +- Caching and optimization + +--- + +### 19. MoE Routing Replay + +Stabilize MoE RL training by recording and replaying expert routing decisions to ensure consistency. + +| Argument | Description | +| --- | --- | +| `--use-routing-replay` | Forward-backward routing consistency in training. ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | +| `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. **Requires `--use-miles-router`**. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | + diff --git a/docs/en/get_started/qa.md b/docs/en/get_started/qa.md index b3163f215..c9e8dad21 100644 --- a/docs/en/get_started/qa.md +++ b/docs/en/get_started/qa.md @@ -49,7 +49,7 @@ 9. **My gradient norm is very high and the training crashes. What should I do?** - First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](./debug.md) for a more in-depth analysis. + First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](../developer_guide/debug.md) for a more in-depth analysis. 10. **My sglang generation takes an extremely long time, GPU power is maxed out, and there's no output for a long while. Why?** @@ -57,7 +57,7 @@ 11. **Sglang shows an `an illegal memory access was encountered` error.** - According to the sglang documentation ([https://docs.sglang.ai/references/troubleshooting.html](https://docs.sglang.ai/references/troubleshooting.html)), this could be an OOM error. Consider reducing the value of `--sglang-mem-fraction-static`. + According to [SGLang documentation](https://docs.sglang.io/references/faq.html), this could be an OOM error. Consider reducing the value of `--sglang-mem-fraction-static`. 12. **A `JSONDecodeError` occurs related to torch compile/inductor.** diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index db07ab705..180075ed3 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -39,13 +39,13 @@ docker run --rm --gpus all --ipc=host --shm-size=16g \ ### Install miles -miles is already installed in the docker image. To update to the latest version, please execute the following command: +miles is already installed in the docker image. To update to the latest version, please execute the following command: ```bash # Path can be adjusted according to actual situation cd /root/miles git pull -pip install -e . +pip install -e . --no-deps ``` ## Model and Dataset Download @@ -105,6 +105,14 @@ PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \ Note that as Megatron will do padding to embedding for better performance, it may happen that the converted embedding is not correct. In that case, please manually set `--vocab-size` during convertion. +For FSDP checkpoints (without `common.pt`), use the dedicated conversion script. Point `--input-dir` to the checkpoint directory (e.g. 
`iter_xxx` or `iter_xxx/model`) and provide the original Hugging Face directory: + +```bash +python tools/convert_fsdp_to_hf.py \ + --input-dir /path/to/fsdp_ckpt/iter_xxx \ + --output-dir /root/fsdp-converted \ + --origin-hf-dir /root/GLM-Z1-9B-0414 +``` ## Training Script and Parameter Overview diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 0c6d8b098..97f6449fe 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -6,7 +6,7 @@ When using miles, parameters are primarily passed for the following purposes: 1. To allocate a portion of the GPUs in the cluster for training and another portion for inference. -2. To load Megatron for the training portion. +2. To load Megatron or FSDP for the training portion. 3. To load SGLang for the inference portion. 4. To configure the hyperparameters required for RL training. @@ -28,6 +28,15 @@ For co-located training and inference, you also need to configure: - `--colocate`: Enables co-located training and inference. When enabled, it ignores `--rollout-num-gpus` and makes the number of GPUs for training and inference equal. +Additionally, miles supports Prefill and Decode disaggregation (PD Disaggregation). You can set the number of servers used for Prefill by setting the `--prefill-num-servers` argument. + +### Choosing Training Backend + +miles supports multiple training backends, which can be selected via the `--train-backend` parameter: + +- `megatron` (default): Uses Megatron-LM as the training backend, supporting efficient training of large-scale models. +- `fsdp`: Uses PyTorch FSDP as the training backend, allowing direct loading of HuggingFace format weights without conversion. + ### Loading Megatron Unlike tools such as SGLang, vLLM, or Hugging Face Trainer, Megatron cannot directly read Hugging Face checkpoints. Instead, the user must configure the parameters for the model to be trained and load Megatron's own checkpoint format. @@ -67,7 +76,7 @@ MODEL_ARGS=( ) ``` -We provide configurations for common models in [scripts/models](../../scripts/models), which you can reuse directly. If you are also using Megatron for pre-training/SFT, you can directly reuse the model configurations from your pre-training/SFT setup. +We provide configurations for common models in [scripts/models](../../../scripts/models), which you can reuse directly. If you are also using Megatron for pre-training/SFT, you can directly reuse the model configurations from your pre-training/SFT setup. Note: @@ -99,7 +108,7 @@ Megatron supports several of its custom checkpoint formats. Here are two of the The `torch` format is Megatron's older storage format. Its structure consists of directories like `mp_rank_xxx`, where each directory corresponds to the checkpoint stored by each rank under a specific parallel partitioning. Because of this, when loading a `torch` format checkpoint, you must ensure that the checkpoint's parallelism strategy matches that of the training task. -We recommend using the `torch_dist` format because it supports automatic parallel sharding, meaning that training tasks with different parallelism settings can share the same checkpoint, which is much more convenient. `torch_dist` is also the default format in the open-source Megatron. A `torch_dist` format checkpoint typically contains a set of `.distcp` files. When using `torch_dist`, you can convert from Hugging Face to `torch_dist` and vice versa using the checkpoint conversion method described in the [README](../../README.md). 
+We recommend using the `torch_dist` format because it supports automatic parallel sharding, meaning that training tasks with different parallelism settings can share the same checkpoint, which is much more convenient. `torch_dist` is also the default format in the open-source Megatron. A `torch_dist` format checkpoint typically contains a set of `.distcp` files. When using `torch_dist`, you can convert from Hugging Face to `torch_dist` and vice versa using the checkpoint conversion method described in the [README](../../../README.md). In terms of storage structure, a Megatron checkpoint typically looks like this, assuming the storage path is `/ckpt/`: @@ -138,6 +147,7 @@ Note: - Before the first training step, miles will synchronize the parameters from Megatron to SGLang. Therefore, the `--hf-checkpoint` does not need to contain the latest training parameters, and you do not need to change the HF checkpoint when resuming training. - By default, SGLang reads the maximum context length from the `config.json` in the Hugging Face checkpoint. You can use the `--sglang-context-length` parameter to override this value to support longer inference. - During co-located training and inference, although Megatron and SGLang will offload sequentially, they still need to leave some memory for each other. You need to adjust SGLang's total VRAM usage by reducing `--sglang-mem-fraction-static`. + - miles supports passing through sgl-router parameters by adding a `router` prefix to the original parameter name. For example, sgl-router's `--balance-abs-threshold` parameter should be set as `--router-balance-abs-threshold`. Since sgl-router uses cache-aware routing by default, it may cause uneven request distribution. You can set `--router-balance-abs-threshold 0` to force balanced distribution, but this may affect prefix cache hit rate in multi-turn conversation scenarios. For details on some of SGLang's customizations and the principles behind how miles incorporates SGLang, please see the "How to Use SGLang" section. @@ -176,14 +186,16 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When - `gspo` ([https://arxiv.org/abs/2507.18071](https://arxiv.org/abs/2507.18071)) - `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262)) - `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347)) -- `--calculate-per-token-loss`: By default, Miles calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`. + - `on_policy_distillation` +- `--calculate-per-token-loss`: By default, miles calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`. - `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl). +- `--true-on-policy-mode`: Enable True On-Policy mode, which strictly ensures that data is generated by the current policy during training. ## Custom Rollout Function miles supports customizing data generation (rollout) to various degrees. - - By default, it uses the `generate_rollout` function from [miles/rollout/sglang\_example.py](../../miles/rollout/sglang_rollout.py) for data generation. 
This file implements an asynchronous (asyncio) data generation flow based on SGLang and supports features like dynamic sampling and partial rollout. + - By default, it uses the `generate_rollout` function from [miles/rollout/sglang_rollout.py](https://github.com/radixark/miles/blob/main/miles/rollout/sglang_rollout.py) for data generation. This file implements an asynchronous (asyncio) data generation flow based on SGLang and supports features like dynamic sampling and partial rollout. - You can completely replace the `generate_rollout` in sglang\_example.py by using the `--rollout-function-path` parameter. You just need to ensure that the function signature passed via `--rollout-function-path` is as follows: @@ -213,7 +225,7 @@ miles supports customizing data generation (rollout) to various degrees. - `evaluation`: A boolean indicating if the rollout is for evaluation. You can configure a separate evaluation function using `--eval-function-path`. - - The returned `Sample` type is defined in [miles/utils/types.py](../../miles/utils/types.py). When implementing, you need to ensure the following fields are correctly set: + - The returned `Sample` type is defined in [miles/utils/types.py](https://github.com/radixark/miles/blob/main/miles/utils/types.py). When implementing, you need to ensure the following fields are correctly set: - `tokens`: The tokens for the prompt + response. - `response_length`: The total length of the response. For multi-turn tasks, this is the length of the tokens remaining after the first-turn prompt. @@ -254,7 +266,7 @@ miles supports customizing data generation (rollout) to various degrees. return sample ``` - For a more complete version, please refer to [miles/rollout/sglang\_example.py](../../miles/rollout/sglang_rollout.py). + For a more complete version, please refer to [miles/rollout/sglang_rollout.py](https://github.com/radixark/miles/blob/main/miles/rollout/sglang_rollout.py). - Sometimes, you may also need to support a custom reward model. This can be configured by setting `--custom-rm-path`. @@ -275,11 +287,11 @@ Some parameters related to miles's resource scheduling are configured by miles i - `--tp-size` in miles is set using `--rollout-num-gpus-per-engine`. - `--model-path` in miles is set using `--hf-checkpoint`. -The way SGLang parameters are integrated into miles can be found in [miles/backends/sglang\_utils/arguments.py](../../miles/backends/sglang_utils/arguments.py). +The way SGLang parameters are integrated into miles can be found in [miles/backends/sglang_utils/arguments.py](https://github.com/radixark/miles/blob/main/miles/backends/sglang_utils/arguments.py). ### How to Use the Router -miles uses [sglang-router](https://github.com/sgl-project/sglang/tree/main/sgl-router) to manage the SGLang servers during the training process. You can configure the address of the [sglang-router](https://github.com/sgl-project/sglang/tree/main/sgl-router) using `--sglang-router-ip` and `--sglang-router-port`. If not configured, a router will be started by default within the cluster. +miles uses [sglang-router](https://github.com/sgl-project/sglang/tree/main/sgl-model-gateway) to manage the SGLang servers during the training process. You can configure the address of the [sglang-router](https://github.com/sgl-project/sglang/tree/main/sgl-model-gateway) using `--sglang-router-ip` and `--sglang-router-port`. If not configured, a router will be started by default within the cluster. 
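+
+For example, to point miles at an externally managed router, the relevant flags might look like the sketch below (the address and port are placeholders):
+
+```bash
+SGLANG_ARGS=(
+    --sglang-router-ip 10.0.0.1    # placeholder: address of an already-running sglang-router
+    --sglang-router-port 30000     # placeholder: its port
+)
+```
+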
After starting, all SGLang servers will register with the router via the `/add_worker` endpoint. When actually generating data, you only need to send HTTP requests to the router, which will perform load balancing and forward the requests to the servers. @@ -291,7 +303,7 @@ miles supports different and lightly modified versions of Megatron by reusing co ### Parameter Configuration -miles directly imports all parameters of the Megatron in the current environment by using `from megatron.training.arguments import parse_args`. If the version of Megatron you are using has parameters defined outside of `parse_args`, you can configure them by passing them in, similar to how it's done in [train.py](../../train.py), for example: +miles directly imports all parameters of the Megatron in the current environment by using `from megatron.training.arguments import parse_args`. If the version of Megatron you are using has parameters defined outside of `parse_args`, you can configure them by passing them in, similar to how it's done in [train.py](https://github.com/radixark/miles/blob/main/train.py), for example: ```python if __name__ == "__main__": @@ -309,4 +321,64 @@ In some customized Megatron implementations, special operations need to be perfo - `--custom-megatron-init-path`: Adds some initialization calls. - `--custom-megatron-before-log-prob-hook-path`: Is called before calculating the log probability. - - `--custom-megatron-before-train-step-hook-path`: Is called before each training step. You could use this to mix in special training losses, for example. \ No newline at end of file + - `--custom-megatron-before-train-step-hook-path`: Is called before each training step. You could use this to mix in special training losses, for example. + +## How to Use FSDP + +miles also support FSDP2 as the training backend, docs [here](https://lmsys.org/blog/2025-12-03-miles-fsdp/). + +> FSDP automatically reads all architecture information via `AutoModelForCausalLM.from_pretrained()`, without manual specification. Megatron requires manual configuration of parameters to read model architecture information. FSDP can read entirely from `config.json`, directly avoiding the weight format conversion step. + +To run FSDP as the training backend, pass `--train-backend fsdp` to enable. + +### Parameters + +Parameters that FSDP used are shown as below in comparison to Megatron, more supports are coming on the way. + +| Configuration Category | Megatron Parameter | FSDP Parameter | Description | +| --- | --- | --- | --- | +| **Model Loading** | `--load` (Megatron checkpoint) + architecture args (`--num-layers`, `--hidden-size` etc.) | `--hf-checkpoint` (Required) | **FSDP**: Directly uses HuggingFace format, no weight conversion needed, architecture inferred via `AutoConfig` | +| **Tensor Parallel** | `--tensor-model-parallel-size` | Coming Soon | | +| **Pipeline Parallel** | `--pipeline-model-parallel-size` | Coming Soon | | +| **Expert Parallel** | `--expert-model-parallel-size` | Coming Soon | | +| **Context Parallel** | `--context-parallel-size` | `--context-parallel-size` | Both support CP | +| **Initial Learning Rate** | `--lr` | `--lr` | Same parameter | +| **Learning Rate Decay** | `--lr-decay-style` (linear/cosine etc.) | `--lr-decay-style` | Same parameter | +| **Warmup** | `--lr-warmup-iters` (steps) | `--lr-warmup-iters` | Same parameter | +| **Min Learning Rate** | `--min-lr` | `--min-lr` | Same parameter | +| **Optimizer Type** | `--optimizer` (adam/sgd etc.) 
| `--optimizer` (default adam) | Basically same | +| **Distributed Optimizer** | `--use-distributed-optimizer` | Built-in to FSDP | FSDP uses distributed optimizer by default | +| **Gradient Checkpoint** | `--recompute-granularity`, `--recompute-method` | `--gradient-checkpointing` | **FSDP**: Simplified to boolean switch | +| **CPU Offload** | Implemented via distributed optimizer | `--fsdp-cpu-offload` | **FSDP**: Offload parameters/gradients/optimizer states to CPU | +| **CPU Backend** | Implemented via distributed optimizer | `--fsdp-cpu-backend` | **FSDP**: Specify CPU backend and use hybrid backend when CPU offload is enabled | +| **Attention Backend** | Decided by Megatron Core | `--attn-implementation` (flash_attention_2/sdpa/eager) | **FSDP**: Directly passed to HuggingFace | +| **Mixed Precision** | `--fp16` or `--bf16` | `--fp16` (bf16 inferred automatically) | Basically same | +| **Training Backend** | Default or `--train-backend megatron` | `--train-backend fsdp` (Required) | Used to switch backend | +| **Config** | | `--config` | **FSDP**: Set additional parameters for FSDP backend | + +### Quick Start + +```bash +# If you need to use WANDB, you need to set the environment variable WANDB_API_KEY in advance +# Download model weights (Qwen3-4B) +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B + +# Download training dataset (dapo-math-17k) +hf download --repo-type dataset zhuzilin/dapo-math-17k \ + --local-dir /root/dapo-math-17k + +# Download evaluation dataset (aime-2024) +hf download --repo-type dataset zhuzilin/aime-2024 \ + --local-dir /root/aime-2024 + +# Clone code and install dependencies +git clone https://github.com/radixark/miles.git +cd miles +pip install -e . --no-deps + + +# FSDP does not require weight conversion, natively supports huggingface format +# Enable reference model, train Qwen3-4B in colocate mode +source /root/miles/scripts/run-qwen3-4B-fsdp.sh +``` + diff --git a/docs/en/index.rst b/docs/en/index.rst index be1427733..afafc6796 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -6,7 +6,7 @@ miles is an LLM post-training framework for RL scaling, providing two core capab - High-Performance Training: Supports efficient training in various modes by connecting Megatron with SGLang; - Flexible Data Generation: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines. -miles is the RL-framework behind GLM-4.5 and GLM-4.6. Apart from models from Z.ai, we also supports the following models: +miles is the RL-framework behind GLM-4.7, GLM-4.6 and GLM-4.5. Apart from models from Z.ai, we also supports the following models: - Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series; - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1); @@ -18,6 +18,7 @@ miles is the RL-framework behind GLM-4.5 and GLM-4.6. Apart from models from Z.a get_started/quick_start.md get_started/usage.md + get_started/customization.md get_started/qa.md .. toctree:: @@ -43,6 +44,7 @@ miles is the RL-framework behind GLM-4.5 and GLM-4.6. Apart from models from Z.a advanced/speculative-decoding.md advanced/fault-tolerance.md advanced/arch-support-beyond-megatron.md + advanced/pd-disaggregation.md .. toctree:: :maxdepth: 1 @@ -52,7 +54,8 @@ miles is the RL-framework behind GLM-4.5 and GLM-4.6. 
Apart from models from Z.a _examples_synced/search-r1/README.md _examples_synced/fully_async/README.md _examples_synced/retool/README.md - _examples_synced/multi_agent/README.md + _examples_synced/multi_agent/README.md + _examples_synced/on_policy_distillation/README.md .. toctree:: :maxdepth: 1 diff --git a/docs/en/platform_support/amd_tutorial.md b/docs/en/platform_support/amd_tutorial.md index 790aede5d..c22fdb2ae 100644 --- a/docs/en/platform_support/amd_tutorial.md +++ b/docs/en/platform_support/amd_tutorial.md @@ -50,13 +50,27 @@ docker run --rm -it \ /bin/bash ``` -Then, download and install miles. +Then, download and install miles: ```bash git clone https://github.com/radixark/miles.git cd miles -pip install -e . +pip install -e . --no-deps ``` +Download the model and data: + +```bash +# hf checkpoint +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B + +# train data +hf download --repo-type dataset zhuzilin/dapo-math-17k \ + --local-dir /root/dapo-math-17k + +# eval data +hf download --repo-type dataset zhuzilin/aime-2024 \ + --local-dir /root/aime-2024 +``` ### Checkpoint Format Conversion @@ -73,19 +87,26 @@ MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') PYTHONPATH=${MEGATRON_LM_PATH} python tools/convert_hf_to_torch_dist.py \ ${MODEL_ARGS[@]} \ --no-gradient-accumulation-fusion \ - --hf-checkpoint model/Qwen3-4B \ - --save model/Qwen3-4B_torch_dist + --hf-checkpoint /root/Qwen3-4B \ + --save /root/Qwen3-4B_torch_dist ``` Note: We implemented a dedicated AMD conversion script that forces a CPU-only conversion workflow using the Gloo backend to bypass hardware-specific issues. A GPU-based script for ROCm is currently in development. -โš ๏ธ If you encounter an issue where miles cannot be found, please run `pip install -e .` in the miles directory. +โš ๏ธ If you encounter an issue where miles cannot be found, please run `pip install -e . --no-deps` in the miles directory. ### Example: Qwen3-4B We provide examples to use [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B), please refer to: -- [Example: Qwen3-4B Model](../../../scripts/run-qwen3-4B-amd.sh): Just run `scripts/run-qwen3-4B-amd.sh` +- [Example: Qwen3-4B Model](https://github.com/radixark/miles/blob/main/scripts/run-qwen3-4B-amd.sh): Just run + +```bash +MILES_DIR=/root \ +MODEL_DIR=/root \ +DATA_DIR=/root \ +bash scripts/run-qwen3-4B-amd.sh +``` โš ๏ธ TODO: ROCM seems to not support `apex` yet. Thus, we need to disable gradient accumulation fusionby adding the `--no-gradient-accumulation-fusion` flag in the training script currently. We will continue investigating how to enable this. diff --git a/examples/DrGRPO/README.md b/examples/DrGRPO/README.md new file mode 100644 index 000000000..56bd39848 --- /dev/null +++ b/examples/DrGRPO/README.md @@ -0,0 +1,50 @@ +# Dr.GRPO Custom Reducer + +This example demonstrates how to use a custom reducer function for Dr.GRPO algorithm. + +## Overview + +By default, miles divides the policy gradient loss by the number of effective tokens in each sample. This custom implementation allows you to divide by a constant value (default: 1000) instead. + +## Usage + +Use `--custom-pg-loss-reducer-function-path` to apply the custom reducer **only to pg_loss**, while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) 
still use the default sum_of_sample_mean: + +```bash +--custom-pg-loss-reducer-function-path examples.Dr.GRPO.custom_reducer.get_pg_loss_reducer +``` + +## Customization + +You can modify the `DIVISOR` constant in `custom_reducer.py` to use a different value: + +```python +# In custom_reducer.py +DIVISOR = 1000.0 # Change this to your desired constant +``` + +## How It Works + +The custom function has the same signature as the default `get_sum_of_sample_mean`: + +```python +def get_pg_loss_reducer( + total_lengths: list[int], + response_lengths: list[int], + loss_masks: list[torch.Tensor], + calculate_per_token_loss: bool = False, +) -> Callable[[torch.Tensor], torch.Tensor]: +``` + +Instead of dividing by `loss_mask_i.sum()` (the number of effective tokens), it divides by the constant `DIVISOR`. + +## Example + +```bash +GRPO_ARGS=( + --advantage-estimator grpo + --custom-pg-loss-reducer-function-path examples.Dr.GRPO.custom_reducer:get_pg_loss_reducer + # ... other arguments +) +``` + diff --git a/examples/DrGRPO/custom_reducer.py b/examples/DrGRPO/custom_reducer.py new file mode 100644 index 000000000..565125a38 --- /dev/null +++ b/examples/DrGRPO/custom_reducer.py @@ -0,0 +1,67 @@ +"""Custom pg_loss reducer for Dr.GRPO. + +This module provides a custom reducer that divides by a constant instead of +the number of effective tokens. This is useful for Dr.GRPO algorithm. + +Usage: + --custom-pg-loss-reducer-function-path examples.Dr.GRPO.custom_reducer:get_pg_loss_reducer +""" + +from collections.abc import Callable + +import torch +from megatron.core import mpu + +# Constant divisor instead of effective token count +DIVISOR = 1000.0 + + +def get_pg_loss_reducer( + total_lengths: list[int], + response_lengths: list[int], + loss_masks: list[torch.Tensor], + calculate_per_token_loss: bool = False, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Custom reducer for pg_loss only. Divides by a constant (DIVISOR) + instead of the number of effective tokens. + + This function is designed to be used with --custom-pg-loss-reducer-function-path + so that only pg_loss uses this custom reducer, while other metrics + (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. + + Note: This implementation only supports cp_size == 1 (no context parallelism). + + Args: + total_lengths: List of total sequence lengths (prompt + response). Unused but kept for API compatibility. + response_lengths: List of response lengths. + loss_masks: List of loss masks for each sample. + calculate_per_token_loss: If True, return sum_of_token (no division). + If False, return sum_of_sample_mean with constant divisor. + + Returns: + A callable function that takes a tensor and returns a scalar tensor. 
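+
+    Example (illustrative only; the variable names are placeholders):
+        reducer = get_pg_loss_reducer(total_lengths, response_lengths, loss_masks)
+        pg_loss = reducer(per_token_pg_loss)  # sum over samples of (masked sum / DIVISOR)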
+ """ + assert mpu.get_context_parallel_world_size() == 1, "This custom reducer only supports cp_size == 1" + + if calculate_per_token_loss: + + def sum_of_token(x: torch.Tensor) -> torch.Tensor: + return sum( + [ + (x_i * loss_mask_i).sum() + for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) + ] + ) + + return sum_of_token + + def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: + return sum( + [ + (x_i * loss_mask_i).sum() / DIVISOR + for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) + ] + ) + + return sum_of_sample_mean diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..88f135241 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,24 @@ +# Examples + +These examples provide concrete examples to leverage miles in your own RL workflow. Some examples are just demonstrative, but most of them are verifiable with a concrete performance score. + +## Directory Structure + +- **[DrGRPO](./DrGRPO)**: Custom reducer for Dr.GRPO algorithm. +- **[eval](./eval)**: Documentation and setup for evaluation environments using NeMo-Skills. +- **[eval_multi_task](./eval_multi_task)**: Example for supporting OOD evaluation tasks, e.g., GPQA, IFBench. +- **[formal_math](./formal_math)**: Examples related to formal math reasoning tasks, including a single round demo. +- **[fully_async](./fully_async)**: Demonstrates fully asynchronous rollout generation for higher efficiency. +- **[geo3k_vlm](./geo3k_vlm)**: Training VLMs with FSDP on a single-turn reasoning task using GRPO on the GEO3K dataset. +- **[geo3k_vlm_multi_turn](./geo3k_vlm_multi_turn)**: VLM multi-turn training (FSDP backend) on Geo3k dataset. +- **[low_precision](./low_precision)**: Examples of FP8 training and inference for improved throughput and stability. +- **[multi_agent](./multi_agent)**: Example of running multi-agent RL with `miles`. +- **[on_policy_distillation](./on_policy_distillation)**: Example implementation for on-policy distillation, extending the reinforcement learning pipeline to support teacherโ€“student distillation directly within on-policy training. +- **[reproducibility](./reproducibility)**: Guides on achieving bitwise experiment reproduction using deterministic modes. +- **[retool](./retool)**: Demonstrates the retool functionality for tool-enabled language model generation. +- **[search-r1](./search-r1)**: A minimal reproduction of Search-R1, featuring multi-turn conversation and tool-calling. +- **[strands-agents](./strands-agents)**: Integration example with the Strands-Agents scaffolding framework. +- **[tau-bench](./tau-bench)**: Training in an agentic multi-turn tool use environment (Tau-bench). +- **[train_infer_mismatch_helper](./train_infer_mismatch_helper)**: Algorithmic methods for rollout correction (e.g., TIS, MIS). +- **[true_on_policy](./true_on_policy)**: Ensures strictly equal log probabilities between inference (SGLang) and training engines. +- **[true_on_policy_vlm](./true_on_policy_vlm)**: "True On-Policy" training demonstration for VLM (Qwen3-VL). diff --git a/examples/eval/README.md b/examples/eval/README.md deleted file mode 100644 index 4e7e0b4c0..000000000 --- a/examples/eval/README.md +++ /dev/null @@ -1,76 +0,0 @@ -# Docs - -## Prerequisites -- A writable host directory for cached data (`/data/.cache`) -- Choose descriptive container names to replace the placeholders (``, ``). 
- -## 1) Prepare host network -```bash -docker network create skills-net -``` - -## 2) Launch the miles container -```bash -docker run \ - -itd \ - --shm-size 32g \ - --gpus all \ - -v /data/.cache:/root/.cache \ - -v /dev/shm:/shm \ - --ipc=host \ - --privileged \ - --network skills-net \ - --name \ - radixark/miles:latest \ - /bin/bash -``` - -## 3) Launch the Skills container -```bash -docker run \ - -itd \ - --shm-size 32g \ - --gpus all \ - -v /data/.cache:/root/.cache \ - -v /dev/shm:/shm \ - --ipc=host \ - --privileged \ - --network skills-net \ - --name \ - --network-alias skills_server \ - guapisolo/nemoskills:0.7.1 \ - /bin/bash -``` - -## 4) Inside the Skills container -Clone repos and install the Skills package: -```bash -git clone -b miles_skills https://github.com/guapisolo/miles.git /opt/miles -git clone -b miles https://github.com/guapisolo/Skills.git /opt/Skills - -cd /opt/Skills -pip install -e . -``` - -Download/prepare datasets: -```bash -cd /opt/Skills/nemo_skills/dataset -python3 aime25/prepare.py -python3 hle/prepare.py -python3 arena-hard/prepare.py -``` - -Start the skills server: -```bash -cd /opt/miles -python examples/eval/nemo_skills/skills_server.py \ - --host 0.0.0.0 \ - --port 9050 \ - --output-root /opt/skills-eval \ - --config-dir examples/eval/nemo_skills/config \ - --cluster local_cluster \ - --max-concurrent-requests 512 \ - --openai-model-name miles-openai-model -``` - -You can now connect to the server at `skills_server:9050` from within the `skills-net` Docker network. The server always proxies evaluation traffic to an OpenAI-compatible sglang router (Miles starts and manage the router), so adjust `--openai-model-name` and `--max-concurrent-requests` as needed for your deployment. diff --git a/examples/eval/eval_delegate.py b/examples/eval/eval_delegate.py index fd6b9878d..1ecabe659 100644 --- a/examples/eval/eval_delegate.py +++ b/examples/eval/eval_delegate.py @@ -91,6 +91,12 @@ def _rebuild_delegate_config( env_cfg = build_skills_eval_env_config(args, env, defaults) if env_cfg is not None: envs.append(env_cfg) + elif env_name == "terminal_bench": + from examples.eval.terminal_bench.tb_config import build_terminal_bench_config + + env_cfg = build_terminal_bench_config(args, env, defaults) + if env_cfg is not None: + envs.append(env_cfg) else: raise ValueError(f"Unknown delegate environment: {env_name}") return envs @@ -151,6 +157,10 @@ def _create_delegate(env_cfg: EvalEnvConfig, router_addr: str): from examples.eval.nemo_skills.skills_client import SkillsEvalClient return SkillsEvalClient.from_config(env_cfg, router_addr) + elif env_name == "terminal_bench": + from examples.eval.terminal_bench.tb_client import TerminalBenchClient + + return TerminalBenchClient.from_config(env_cfg, router_addr) logger.warning("No delegate client registered for environment: %s", env_name) return None diff --git a/examples/eval/nemo_skills/README.md b/examples/eval/nemo_skills/README.md new file mode 100644 index 000000000..4271a8c7d --- /dev/null +++ b/examples/eval/nemo_skills/README.md @@ -0,0 +1,162 @@ +# Evaluation with Nemo Skills + +This directory contains configuration and utilities for offloading complex evaluation benchmarks to a separate environment using the `eval_delegate` mechanism. It is designed to integrate with [Nemo Skills](https://github.com/NVIDIA/NeMo-Skills) for running benchmarks like AIME25, Arena-Hard, and HLE, which may require specific environments distinct from the main training setup. 
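+
+On the miles side, the only wiring needed is an eval config whose `delegate` section points at this server. A minimal sketch is shown below; the field names follow `examples/eval/scripts/eval_tb_example.yaml` in this repository, while the `name` value and the exact contents of `multi_tasks.yaml` are assumptions to verify against your setup:
+
+```yaml
+eval:
+  datasets: []                    # non-delegated eval datasets, if any
+  delegate:
+    - name: nemo_skills           # assumed delegate name; check multi_tasks.yaml
+      url: http://skills_server:9050
+      timeout_secs: 86400
+      max_retries: 1
+      model_name: miles-openai-model
+```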
+ +## Overview + +The setup allows miles to delegate evaluation tasks to a dedicated "Skills" server. This creates a clear separation of concerns: + +1. **miles Container**: Runs the main training loop and hosts the model using SGLang. +2. **Skills Container**: Hosts the `nemo_skills` environment, runs the evaluation logic, and queries the model running in the miles container. + +## Prerequisites + +- A writable host directory for cached data (e.g., `/data/.cache`). +- Docker installed with NVIDIA GPU support. + +## Setup Instructions + +### Prepare Host Network + +Create a Docker network to allow communication between the miles and Skills containers. + +```bash +docker network create skills-net +``` + +### Launch the miles Container + +Start the main container where miles and the model will run. Replace `` with your desired name (e.g., `miles_main`). + +```bash +docker run \ + -itd \ + --shm-size 32g \ + --gpus all \ + -v /data/.cache:/root/.cache \ + -v /dev/shm:/shm \ + --ipc=host \ + --privileged \ + --network skills-net \ + --name \ + radixark/miles:latest \ + /bin/bash +``` + +### Launch the Skills Container + +Start the container that will run the evaluation benchmarks. Replace `` with your desired name (e.g., `skills_env`). + +```bash +docker run \ + -itd \ + --shm-size 32g \ + --gpus all \ + -v /data/.cache:/root/.cache \ + -v /dev/shm:/shm \ + --ipc=host \ + --privileged \ + --network skills-net \ + --name \ + --network-alias skills_server \ + guapisolo/nemoskills:0.7.1 \ + /bin/bash +``` + +### Configure the Skills Container + +Enter the **Skills container** and set up the environment. + +**a) Install Dependencies** + +```bash +# Clone repositories +git clone -b miles_skills https://github.com/guapisolo/miles.git /opt/miles +git clone -b miles https://github.com/guapisolo/Skills.git /opt/Skills + +# Install Skills package +cd /opt/Skills +pip install -e . --no-deps +``` + +**b) Prepare Datasets** + +Download and prepare the datasets you intend to use. + +```bash +cd /opt/Skills/nemo_skills/dataset +python3 aime25/prepare.py +python3 hle/prepare.py +python3 arena-hard/prepare.py +``` + +**c) Start the Evaluation Server** + +Start the server that listens for evaluation requests from miles. + +```bash +cd /opt/miles +python examples/eval/nemo_skills/skills_server.py \ + --host 0.0.0.0 \ + --port 9050 \ + --output-root /opt/skills-eval \ + --config-dir examples/eval/nemo_skills/config \ + --cluster local_cluster \ + --max-concurrent-requests 512 \ + --openai-model-name miles-openai-model +``` +*Note: You can now connect to the server at `skills_server:9050` from within the `skills-net` Docker network. The server always proxies evaluation traffic to an OpenAI-compatible sglang router (miles starts and manage the router), so adjust `--openai-model-name` and `--max-concurrent-requests` as needed for your deployment. + +## Running Evaluation + +The example scripts are located in `examples/eval/scripts`. Here is an example workflow for training Qwen3-4B with delegated evaluation. + +### Prepare miles Container + +Enter the **miles container** and install the package. + +```bash +cd /root/miles +git pull +pip install -e . 
--no-deps +``` + +### Download Model and Data + +```bash +# Download model weights (Qwen3-4B) +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B + +# Download training dataset (dapo-math-17k) +hf download --repo-type dataset zhuzilin/dapo-math-17k \ + --local-dir /root/dapo-math-17k +``` + +### Convert Model to Megatron-LM Format + +You need to convert the HF model to the format required by Megatron-LM. Ensure you load the correct model arguments first. + +```bash +# Source model arguments +source scripts/models/qwen3-4B.sh + +# Convert model +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-4B \ + --save /root/Qwen3-4B_torch_dist +``` + +### Run the Training Script + +Run the training script. + +```bash +bash examples/eval/scripts/run-qwen3-4B.sh +``` + +## Configuration + +The evaluation configuration is defined in `examples/eval/scripts/multi_tasks.yaml`. It specifies: +- `delegate`: Configurations for the external skills server (URL, timeouts). +- `datasets`: List of datasets to evaluate on (e.g., `aime25`, `arena-hard`). diff --git a/examples/eval/scripts/eval_tb_example.yaml b/examples/eval/scripts/eval_tb_example.yaml new file mode 100644 index 000000000..2e2308981 --- /dev/null +++ b/examples/eval/scripts/eval_tb_example.yaml @@ -0,0 +1,29 @@ +eval: + defaults: + n_samples_per_eval_prompt: 1 + temperature: 0.6 + top_p: 0.95 + top_k: -1 + max_response_len: 24576 + datasets: # these eval tasks go through miles dataset config and default rollout function (miles.rollout.sglang_rollout.generate_rollout) + - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa + path: /root/gpqa/gpqa_eval.jsonl + rm_type: gpqa + n_samples_per_eval_prompt: 2 + - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench + path: /root/ifbench/IFBench_eval.jsonl + rm_type: ifbench + n_samples_per_eval_prompt: 1 + delegate: + - name: terminal_bench + url: http://172.17.0.1:9051 # Port must match the TB server running on the host machine + timeout_secs: 86400 # 24 hours + max_retries: 1 # HTTP request retries from Miles to the TB server + model_name: qwen3-8b + api_base: http://127.0.0.1:30005/v1 # Port must match the sglang router port set in run-eval-tb-qwen.sh + dataset_path: /mnt/data/xinyu/program/miles-tb/terminal-bench/tasks # Dataset path on the host machine + # task_ids: + # - hello-world + # n_tasks: 10 + n_attempts: 1 # TB task-level retries (per task within tb run) + n_concurrent: 8 \ No newline at end of file diff --git a/examples/eval/scripts/run-eval-tb-qwen.sh b/examples/eval/scripts/run-eval-tb-qwen.sh new file mode 100644 index 000000000..471a59d56 --- /dev/null +++ b/examples/eval/scripts/run-eval-tb-qwen.sh @@ -0,0 +1,159 @@ +#!/bin/bash + +# Example launcher that reuses the Qwen3-8B recipe but delegates evaluation to an +# external Terminal Bench server via the eval_delegate_rollout wrapper. + +# Clean up any stale processes from a previous run. 
+pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 +export MILES_HOST_IP=${MILES_HOST_IP:-"127.0.0.1"} + +MODEL_DIR="${MODEL_DIR:-/root/.cache}" +export MODEL_DIR + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +source "${REPO_ROOT}/scripts/models/qwen3-8B.sh" + +# Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. +EVAL_CONFIG_PATH=${TB_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/eval_tb_example.yaml"} + +CKPT_ARGS=( + --hf-checkpoint ${MODEL_DIR}/OpenThinker-Agent-v1 # huggingface-cli download open-thoughts/OpenThinker-Agent-v1 + --ref-load ${MODEL_DIR}/OpenThinker-Agent-v1_torch_dist + # --load ${MODEL_DIR}/OpenThinker-Agent-v1_miles/ + --save ${MODEL_DIR}/OpenThinker-Agent-v1_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 5 + --eval-config "${EVAL_CONFIG_PATH}" + --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project miles-eval + --wandb-group qwen3-8b-eval + --wandb-key ${WANDB_KEY} # export WANDB_KEY="your_key" +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + --sglang-router-port 30005 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export CUDA_VISIBLE_DEVICES=0,1 + +ray start --head --node-ip-address ${MASTER_ADDR} --port 6380 --num-gpus 2 \ + --disable-usage-stats \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=8266 \ + --dashboard-agent-listen-port 52366 \ + --dashboard-agent-grpc-port 52367 \ + --runtime-env-agent-port 52368 + + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + +ray job submit --address="http://${MASTER_ADDR}:8266" \ + --working-dir "${REPO_ROOT}" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 2 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + 
${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/examples/eval/scripts/run-qwen3-4B.sh b/examples/eval/scripts/run-qwen3-4B.sh index 4343377f1..34891126d 100644 --- a/examples/eval/scripts/run-qwen3-4B.sh +++ b/examples/eval/scripts/run-qwen3-4B.sh @@ -36,10 +36,10 @@ source "${REPO_ROOT}/scripts/models/qwen3-4B.sh" EVAL_CONFIG_PATH=${SKILLS_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/multi_tasks.yaml"} CKPT_ARGS=( - --hf-checkpoint /root/shared/Qwen3-4B - --ref-load /root/shared/Qwen3-4B_torch_dist - --load /root/shared/Qwen3-4B_miles/ - --save /root/shared/Qwen3-4B_miles/ + --hf-checkpoint /root/Qwen3-4B + --ref-load /root/Qwen3-4B_torch_dist + --load /root/Qwen3-4B_miles/ + --save /root/Qwen3-4B_miles/ --save-interval 20 ) @@ -122,7 +122,9 @@ MISC_ARGS=( ) export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export CUDA_VISIBLE_DEVICES=6,7 +# export CUDA_VISIBLE_DEVICES=0,1 +# Set Up Your GPUs for Training + ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 2 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 RUNTIME_ENV_JSON="{ diff --git a/examples/eval/terminal_bench/README.md b/examples/eval/terminal_bench/README.md new file mode 100644 index 000000000..341e543fc --- /dev/null +++ b/examples/eval/terminal_bench/README.md @@ -0,0 +1,129 @@ +# Terminal Bench Eval + +This folder wires Terminal Bench (TB) into Miles as an eval delegate. The TB run happens on the host via the `tb` CLI, and Miles reads back aggregated metrics such as `accuracy`, `n_resolved`, `n_unresolved`, `pass_at_k/*`, and token stats like `total_input_tokens_mean/median` and `total_output_tokens_mean/median`. + +## What runs where + +- Miles runs your training/eval loop inside the Docker container. +- Miles calls the TB delegate client. +- The TB delegate server (`tb_server.py`) runs `tb run ...` on the host. +- The server reads the latest TB JSON results and returns metrics to Miles. + +## 1) Get the code (host) + +```bash +mkdir miles-tb +cd miles-tb +git clone https://github.com/radixark/miles.git +git clone https://github.com/laude-institute/terminal-bench +``` + +## 2) Launch the Miles container + +```bash +docker run \ + -itd \ + --gpus all \ + --shm-size 32g \ + --network host \ + --ipc=host \ + --privileged \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --ulimit nofile=65536:65536 \ + -v /mnt/data/.cache:/root/.cache \ + -v $(pwd):/shared/miles-tb \ + --name \ + radixark/miles:latest \ + /bin/bash +``` + +## 3) Inside the Miles container + +```bash +docker exec -it /bin/bash +``` + +## 4) Terminal Bench environment (host) + +Run on the machine that will host `tb_server.py` (where you cloned both repos): + +```bash +# Host machine terminal (outside Docker) +uv venv --python 3.13 .venv +source .venv/bin/activate + +uv pip install terminal-bench/. +uv pip install -r miles/examples/eval/terminal_bench/requirements.txt +``` + +Notes: +- Use your local repo paths if they are not `./miles` and `./terminal-bench`. + +## 5) Start the Terminal Bench server + +Run on the host (same machine where `tb` works): + +```bash +python miles/examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9051 \ + --output-root tb_eval_output +``` + +What it does: +- Uses `OPENAI_API_KEY=EMPTY` +- Runs `tb run -a terminus-2 -m openai/ ... 
--n-concurrent 8` +- Waits for completion, then returns `accuracy`, `n_resolved`, + `n_unresolved`, `pass_at_k/*`, and token stats such as + `total_input_tokens_mean/median` and `total_output_tokens_mean/median` + +## 6) Run the eval script (example) + +If you use the provided Qwen eval launcher (`run-eval-tb-qwen.sh`), follow the steps below to run Terminal-Bench evaluation. + +First, update the `dataset_path` in `eval_tb_example.yaml` to the local path of `terminal-bench/tasks` on your host (not an internal Docker-only path). + +Then download the HuggingFace model checkpoint inside the Miles container: + +```bash +huggingface-cli download open-thoughts/OpenThinker-Agent-v1 \ +--local-dir /root/.cache/OpenThinker-Agent-v1 +``` + +After downloading, convert the HuggingFace checkpoint to Miles's torch distributed format. From the Miles root directory, run: + +```bash +cd /shared/miles-tb/miles +source scripts/models/qwen3-8B.sh + +export PYTHONPATH=/root/Megatron-LM:/shared/miles-tb/miles + +python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/.cache/OpenThinker-Agent-v1 \ + --save /root/.cache/OpenThinker-Agent-v1_torch_dist +``` + +Finally, run the following command inside the Miles container: + +```bash +bash miles/examples/eval/scripts/run-eval-tb-qwen.sh 2>&1 | tee run.log +``` + +For convenience, you can restrict the evaluation scope in `eval_tb_example.yaml`, either by specifying a single task or multiple tasks (`task_ids`), or by limiting the number of tasks via `n_tasks`. + +## 7) Common Issues + +When running Miles inside a Docker container with `--network host`, Ray may encounter port conflicts due to shared networking with the host. + +In some cases, this manifests as Ray failing to start or reporting Redis- or session-related errors. This can usually be resolved by explicitly assigning unused ports when starting the Ray head node, for example by setting a non-default `--port` and `--dashboard-port`. + +In more severe cases, Ray job submission may fail with errors indicating that no available agent can accept jobs. This typically happens when the dashboard agent or runtime environment agent ports are also in conflict. In such situations, explicitly specifying the agent-related ports (e.g. `--dashboard-agent-listen-port`, `--dashboard-agent-grpc-port`, and `--runtime-env-agent-port`) when starting Ray can resolve the issue. + +If the TB server cannot connect to the Miles server through the sglang router (`InternalServerError`), check which address is actually listening on the router port (e.g. 30005 in this example) and update the `api_base` in `eval_tb_example.yaml` accordingly: + +```bash +ss -lntp | grep 30005 +``` + +You may see `Parser warnings`, `Context length exceeded`, `Command 1 should end with newline`, `Harness execution failed` in `tb_server.py` logs. They are warnings from Terminal Bench and can be ignored if runs proceed normally. 
\ No newline at end of file diff --git a/examples/eval/terminal_bench/__init__.py b/examples/eval/terminal_bench/__init__.py new file mode 100644 index 000000000..6d2704250 --- /dev/null +++ b/examples/eval/terminal_bench/__init__.py @@ -0,0 +1 @@ +"""Terminal Bench evaluation helpers.""" diff --git a/examples/eval/terminal_bench/requirements.txt b/examples/eval/terminal_bench/requirements.txt new file mode 100644 index 000000000..1a0006c93 --- /dev/null +++ b/examples/eval/terminal_bench/requirements.txt @@ -0,0 +1,3 @@ +flask +omegaconf +requests diff --git a/examples/eval/terminal_bench/tb_client.py b/examples/eval/terminal_bench/tb_client.py new file mode 100644 index 000000000..2a93b7161 --- /dev/null +++ b/examples/eval/terminal_bench/tb_client.py @@ -0,0 +1,104 @@ +import logging +import time +from typing import Any + +import requests +from examples.eval.eval_delegate import EvalClient, EvalDelegateError +from examples.eval.terminal_bench.tb_config import TerminalBenchConfig + +logger = logging.getLogger(__name__) + + +class TerminalBenchClient(EvalClient): + """HTTP client that proxies evaluation requests to the Terminal Bench server.""" + + def __init__(self, config: TerminalBenchConfig, router_url: str): + super().__init__(config.name or "terminal_bench") + self._config = config + endpoint = (config.url or "").rstrip("/") + if endpoint.endswith("/evaluate"): + base_endpoint = endpoint[: -len("/evaluate")] + else: + base_endpoint = endpoint + self._endpoint = f"{base_endpoint}/evaluate" if base_endpoint else "" + self._status_endpoint = f"{base_endpoint}/status" if base_endpoint else "" + self._timeout_secs = float(config.timeout_secs) + self._max_retries = max(1, int(config.max_retries)) + self._headers = dict(config.headers or {}) + self._session = requests.Session() + + @classmethod + def from_config(cls, config: TerminalBenchConfig, router_url: str): + if not config.url: + return None + return cls(config, router_url) + + def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]: + payload = self._build_payload(args, rollout_id) + response = self._request(payload) + metrics = response.get("raw_metrics", {}) + return metrics, response + + def _build_payload(self, args, rollout_id: int) -> dict[str, Any]: + payload = { + "model_name": self._config.model_name, + "api_base": self._config.api_base, + "n_tasks": self._config.n_tasks, + "n_concurrent": self._config.n_concurrent, + "metric_prefix": self._config.name, + } + if self._config.dataset_path: + payload["dataset_path"] = self._config.dataset_path + if self._config.task_ids: + payload["task_ids"] = list(self._config.task_ids) + if self._config.n_attempts is not None: + payload["n_attempts"] = self._config.n_attempts + return payload + + def _request(self, payload: dict[str, Any]) -> dict[str, Any]: + last_error: Exception | None = None + for attempt in range(1, self._max_retries + 1): + try: + response = self._session.post( + self._endpoint, + json=payload, + timeout=self._timeout_secs, + headers=self._headers, + ) + response.raise_for_status() + if not response.content: + return {} + body = response.json() + if body.get("status") == "completed": + return body + job_id = body.get("job_id") + if not job_id: + return body + return self._poll_status(job_id) + except requests.RequestException as exc: + last_error = exc + logger.warning( + "Terminal Bench delegate request failed (attempt %s/%s): %s", attempt, self._max_retries, exc + ) + if attempt < self._max_retries: + time.sleep(min(2**attempt, 30)) + 
raise EvalDelegateError("Terminal Bench evaluation request failed") from last_error + + def _poll_status(self, job_id: str) -> dict[str, Any]: + status_url = f"{self._status_endpoint}/{job_id}" + deadline = time.time() + self._timeout_secs + while time.time() < deadline: + response = self._session.get(status_url, timeout=min(self._timeout_secs, 30), headers=self._headers) + response.raise_for_status() + if not response.content: + time.sleep(2) + continue + body = response.json() + status = body.get("status") + if status == "completed": + return body + if status == "failed": + error = body.get("error") or "Terminal Bench job failed" + raise EvalDelegateError(error) + time.sleep(2) + raise EvalDelegateError("Terminal Bench evaluation timed out") diff --git a/examples/eval/terminal_bench/tb_config.py b/examples/eval/terminal_bench/tb_config.py new file mode 100644 index 000000000..f57b445dd --- /dev/null +++ b/examples/eval/terminal_bench/tb_config.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +from examples.eval.eval_delegate import EvalEnvConfig + + +@dataclass +class TerminalBenchConfig(EvalEnvConfig): + """Environment configuration shared by the Terminal Bench client/server.""" + + model_name: str = "qwen3-8b" + api_base: str = "http://127.0.1.1:30001/v1" + dataset_path: str | None = None + n_tasks: int | None = None + task_ids: list[str] = field(default_factory=list) + n_attempts: int | None = None + n_concurrent: int = 8 + + @classmethod + def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> TerminalBenchConfig: + clean_raw = dict(raw_env_config or {}) + clean_raw.pop("type", None) + base_cfg: TerminalBenchConfig = super().parse(clean_raw, defaults) + + field_casts = { + "model_name": str, + "api_base": str, + "n_attempts": int, + "n_tasks": int, + "n_concurrent": int, + "dataset_path": str, + } + + for key, caster in field_casts.items(): + value = clean_raw.get(key) + if value is not None: + setattr(base_cfg, key, caster(value)) + + task_ids = clean_raw.get("task_ids") + if isinstance(task_ids, (list, tuple)): + base_cfg.task_ids = [str(item) for item in task_ids if item] + elif task_ids is not None: + raise ValueError("task_ids must be a list") + + return base_cfg + + +def build_terminal_bench_config(args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]): + return TerminalBenchConfig.parse(args, raw_env_config, defaults) diff --git a/examples/eval/terminal_bench/tb_server.py b/examples/eval/terminal_bench/tb_server.py new file mode 100644 index 000000000..58c9d54ad --- /dev/null +++ b/examples/eval/terminal_bench/tb_server.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +""" +Simple HTTP server that proxies Miles evaluation requests to the `tb run` +command shipped with Terminal Bench. + +Usage: + python examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9050 \ + --output-root /opt/tb-eval + +Miles (or Miles-compatible runners) should POST the payload described in +`EvalRequestPayload` to http://:/evaluate. The server blocks until +`tb run` finishes, then returns aggregated metrics along with paths to the +generated artifacts (logs + raw metrics). 
+""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import shlex +import statistics +import subprocess +import sys +import threading +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from flask import Flask, jsonify, request +from omegaconf import OmegaConf +from omegaconf.errors import OmegaConfBaseException + +logger = logging.getLogger("terminal_bench_server") +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + +# --------------------------------------------------------------------------- +# Request payload helpers +# --------------------------------------------------------------------------- + + +@dataclass +class EvalRequestPayload: + model_name: str = "" + api_base: str = "" + n_tasks: int | None = None + n_concurrent: int | None = None + dataset_path: str | None = None + task_ids: list[str] | None = None + n_attempts: int | None = None + metric_prefix: str | None = None + + +@dataclass +class JobRecord: + job_id: str + status: str + run_id: str + command: str + output_dir: str + log_path: str + raw_metrics: dict[str, Any] | None = None + error: str | None = None + created_at: float = field(default_factory=time.time) + started_at: float | None = None + finished_at: float | None = None + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "job_id": self.job_id, + "status": self.status, + "run_id": self.run_id, + "command": self.command, + "output_dir": self.output_dir, + "log_path": self.log_path, + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + } + if self.raw_metrics is not None: + payload["raw_metrics"] = self.raw_metrics + if self.error: + payload["error"] = self.error + return payload + + +# --------------------------------------------------------------------------- +# Configuration + command helpers +# --------------------------------------------------------------------------- + + +def _normalize_model_name(model_name: str) -> str: + name = (model_name or "").strip() + if not name: + return "" + if "/" in name: + return name + return f"openai/{name}" + + +@dataclass +class ServerConfig: + output_root: Path + + @classmethod + def from_args(cls, args: argparse.Namespace) -> ServerConfig: + return cls(output_root=Path(args.output_root).expanduser().resolve()) + + +class TerminalBenchEvaluator: + def __init__(self, config: ServerConfig): + self._config = config + self._lock = threading.Lock() + self._jobs_lock = threading.Lock() + self._jobs: dict[str, JobRecord] = {} + self._config.output_root.mkdir(parents=True, exist_ok=True) + self._log_root = REPO_ROOT.parent / "tb_eval_logs" + self._log_root.mkdir(parents=True, exist_ok=True) + + def evaluate(self, payload: EvalRequestPayload) -> dict[str, Any]: + if not payload.model_name: + raise ValueError("Missing `model_name` in request payload.") + if not payload.api_base: + raise ValueError("Missing `api_base` in request payload.") + + job_id = uuid.uuid4().hex + run_id = f"{int(time.time())}-{job_id[:8]}" + run_dir = self._config.output_root / run_id + + command = self._build_command(payload, run_id) + command_str = " ".join(shlex.quote(part) for part in command) + log_path = self._log_root / f"{run_id}.log" + + record = JobRecord( + job_id=job_id, + status="queued", + 
run_id=run_id, + command=command_str, + output_dir=str(run_dir), + log_path=str(log_path), + ) + with self._jobs_lock: + self._jobs[job_id] = record + + thread = threading.Thread( + target=self._run_job, + args=(job_id, payload, run_dir, command, log_path), + daemon=True, + ) + thread.start() + + return { + "job_id": job_id, + "status": "queued", + "status_url": f"/status/{job_id}", + "run_id": run_id, + "command": command_str, + "output_dir": str(run_dir), + "log_path": str(log_path), + } + + def _run_job( + self, + job_id: str, + payload: EvalRequestPayload, + run_dir: Path, + command: list[str], + log_path: Path, + ) -> None: + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "running" + record.started_at = time.time() + + env = self._build_env() + logger.info("Starting Terminal Bench run: %s", " ".join(shlex.quote(part) for part in command)) + try: + with self._lock: + self._run_command(command, env=env, log_path=log_path) + metrics = self._collect_metrics(run_dir) + if payload.metric_prefix: + metrics = {payload.metric_prefix: metrics} + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "completed" + record.raw_metrics = metrics + record.finished_at = time.time() + except Exception as exc: # noqa: BLE001 + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "failed" + record.error = str(exc) + record.finished_at = time.time() + + def get_job_status(self, job_id: str) -> dict[str, Any] | None: + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return None + return record.to_dict() + + def _build_command(self, payload: EvalRequestPayload, run_id: str) -> list[str]: + # 1. Normalize model name (add openai/ prefix) + model_name = _normalize_model_name(payload.model_name) + + cmd = [ + "tb", + "run", + "-a", + "terminus-2", # Added Agent flag + "--output-path", + str(self._config.output_root), + "--run-id", + run_id, + ] + + # 2. Add model + if model_name: + cmd.extend(["--model", model_name]) + + # 3. Add Agent kwargs (Use api_base exactly like the CLI command) + if payload.api_base: + cmd.extend(["--agent-kwarg", f"api_base={payload.api_base}"]) + + if payload.dataset_path: + cmd.extend(["--dataset-path", payload.dataset_path]) + + if payload.n_attempts is not None: + cmd.extend(["--n-attempts", str(payload.n_attempts)]) + + # 4. Add n_tasks if present + task_ids = [] + if payload.task_ids: + task_ids.extend([str(item) for item in payload.task_ids if item]) + if task_ids: + for task_id in task_ids: + cmd.extend(["--task-id", task_id]) + elif payload.n_tasks is not None: + cmd.extend(["--n-tasks", str(payload.n_tasks)]) + + # 5. 
Add concurrency + n_concurrent = payload.n_concurrent + if n_concurrent is None: + n_concurrent = 1 + cmd.extend(["--n-concurrent", str(n_concurrent)]) + + return cmd + + def _build_env(self) -> dict[str, str]: + env = os.environ.copy() + # Inject env var to simulate "OPENAI_API_KEY=EMPTY" + env["OPENAI_API_KEY"] = "EMPTY" + return env + + @staticmethod + def _run_command(cmd: list[str], *, env: dict[str, str], log_path: Path): + with open(log_path, "w", encoding="utf-8") as log_file: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + bufsize=1, + ) + assert process.stdout is not None + for line in process.stdout: + log_file.write(line) + log_file.flush() + sys.stdout.write(line) + sys.stdout.flush() + retcode = process.wait() + if retcode != 0: + with open(log_path, encoding="utf-8", errors="ignore") as log_file: + tail = "".join(log_file.readlines()[-200:]) + raise RuntimeError(f"`tb run` failed with exit code {retcode}. See {log_path}\n{tail}") + + @staticmethod + def _collect_metrics(run_dir: Path) -> dict[str, Any]: + metrics_path = run_dir / "results.json" + if not metrics_path.exists(): + logger.warning("Results file missing at %s", metrics_path) + return {} + + metrics = TerminalBenchEvaluator._extract_metrics(metrics_path) + if not metrics: + logger.warning("No accuracy/n_resolved metrics found in %s", metrics_path) + return metrics + + @staticmethod + def _extract_metrics(metrics_path: Path) -> dict[str, Any]: + try: + with open(metrics_path, encoding="utf-8") as fp: + metrics_data = json.load(fp) + except json.JSONDecodeError as exc: + logger.warning("Failed to parse %s: %s", metrics_path, exc) + return {} + + metrics: dict[str, Any] = {} + + # core metrics + accuracy = metrics_data.get("accuracy") + if isinstance(accuracy, (int, float)): + metrics["accuracy"] = float(accuracy) + + n_resolved = metrics_data.get("n_resolved") + if isinstance(n_resolved, (int, float)): + metrics["n_resolved"] = int(n_resolved) + + n_unresolved = metrics_data.get("n_unresolved") + if isinstance(n_unresolved, (int, float)): + metrics["n_unresolved"] = int(n_unresolved) + + # pass@k flatten + pass_at_k = metrics_data.get("pass_at_k") + if isinstance(pass_at_k, dict): + for k, v in pass_at_k.items(): + if isinstance(v, (int, float)): + metrics[f"pass_at_k/{k}"] = float(v) + + # token stats from per-task results + results = metrics_data.get("results") + if isinstance(results, list): + input_tokens = [ + r.get("total_input_tokens") + for r in results + if isinstance(r, dict) and isinstance(r.get("total_input_tokens"), (int, float)) + ] + output_tokens = [ + r.get("total_output_tokens") + for r in results + if isinstance(r, dict) and isinstance(r.get("total_output_tokens"), (int, float)) + ] + + if input_tokens: + metrics["total_input_tokens_mean"] = float(statistics.mean(input_tokens)) + metrics["total_input_tokens_median"] = float(statistics.median(input_tokens)) + if output_tokens: + metrics["total_output_tokens_mean"] = float(statistics.mean(output_tokens)) + metrics["total_output_tokens_median"] = float(statistics.median(output_tokens)) + + return metrics + + +# --------------------------------------------------------------------------- +# HTTP server +# --------------------------------------------------------------------------- + + +def build_app(evaluator: TerminalBenchEvaluator) -> Flask: + app = Flask(__name__) + + @app.get("/health") + def health_check(): + return jsonify({"status": "ok"}) + + @app.post("/evaluate") + def 
evaluate_endpoint(): + try: + raw_payload = request.get_json(force=True, silent=False) + cfg = OmegaConf.merge( + OmegaConf.structured(EvalRequestPayload), + OmegaConf.create(raw_payload or {}), + ) + payload = OmegaConf.to_object(cfg) + result = evaluator.evaluate(payload) + return jsonify(result) + except OmegaConfBaseException as exc: + logger.exception("Invalid request payload") + return jsonify({"error": str(exc)}), 400 + except Exception as exc: # noqa: BLE001 + logger.exception("Evaluation failed") + return jsonify({"error": str(exc)}), 500 + + @app.get("/status/") + def status_endpoint(job_id: str): + status = evaluator.get_job_status(job_id) + if status is None: + return jsonify({"error": "job not found"}), 404 + return jsonify(status) + + return app + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the Terminal Bench evaluation HTTP server.") + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=9050) + parser.add_argument( + "--output-root", + type=str, + default="./terminal-bench-output", + help="Directory to store `tb run` outputs.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + config = ServerConfig.from_args(args) + evaluator = TerminalBenchEvaluator(config) + app = build_app(evaluator) + logger.info( + "Starting Terminal Bench evaluation server on %s:%s (output root=%s)", + args.host, + args.port, + config.output_root, + ) + app.run(host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/experimental/README.md b/examples/experimental/README.md new file mode 100644 index 000000000..efc6363d9 --- /dev/null +++ b/examples/experimental/README.md @@ -0,0 +1 @@ +The examples under this directory are not fully verified, only for experimental use/develop purpose. diff --git a/examples/experimental/swe-agent/README.md b/examples/experimental/swe-agent/README.md new file mode 100644 index 000000000..d71ad3ecd --- /dev/null +++ b/examples/experimental/swe-agent/README.md @@ -0,0 +1,132 @@ +# SWE-agent training Example + +## Introduction + +This is an example for SWE-agent training. This example uses NVIDIA's Nemo-Gym as the Gym environment implement, SWE-Gym as the training data, and SWE-bench as the evaluation. + +This implementation of this example is partially in submodules below: +- Nemo-Gym: https://github.com/yueming-yuan/Gym/tree/miles-swe-agent +- mini-swe-agent: https://github.com/yueming-yuan/nv-mini-swe-agent/tree/miles-swe-agent + + +## Prepare environment +### Update submodules +```bash +git submodule update --init --recursive . +``` +### Docker settings +```bash +# 1. create a docker network +docker network create swe-net + +# 2. create environment docker +docker run -itd \ + --name swe_env \ + --shm-size 16g \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -v /mnt/data:/data \ + -v /home/sglang-rl/:/workspace \ + --ipc=host \ + --ulimit nofile=65536:65536 \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --network swe-net \ + ubuntu:latest \ + /bin/bash + +# 3. 
create miles docker +docker run -itd \ + --shm-size 32g \ + --gpus all \ + -v /mnt/data/cache/huggingface:/root/.cache/huggingface \ + -v /mnt/data:/data \ + -v /home/sglang-rl/:/workspace \ + --ipc=host \ + --ulimit nofile=65536:65536 \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --privileged \ + --network swe-net \ + --name miles_ \ + radixark/miles:latest \ + /bin/zsh + +# 4. install utils in environment docker +docker exec -it swe_env /bin/bash +apt update && apt install -y zsh curl git python3 python3-pip docker.io +``` +note: `-v /var/run/docker.sock:/var/run/docker.sock` is required for Docker-in-Docker SWE environment execution; use `--network swe-net` to enable communication between training & environment. + +### Installation + +In **environment docker**, install Gym +```bash +git clone https://github.com/yueming-yuan/Gym +cd Gym + +curl -LsSf https://astral.sh/uv/install.sh | sh +source $HOME/.local/bin/env +uv venv --python 3.12 && source .venv/bin/activate +uv sync --extra dev --group docs + +# configure env.yaml +echo "policy_base_url: https://api.openai.com/v1 +policy_api_key: your-openai-api-key +policy_model_name: gpt-4.1-2025-04-14 +default_host: 0.0.0.0" > env.yaml +``` +note: set host IP to `0.0.0.0` to enable communications between dockers. + +then set up for SWE-agent server: +```bash +cd responses_api_agents/mini_swe_agent +uv pip install -r requirements.txt +``` +Now you should be able to run the SWE-agent server. + +For **miles docker** setup, please follow the standard setup process. + +## Preparing data +In **miles docker**, download **SWE-Gym** data from huggingface and convert it to Miles' prompt data format with this script. +``` +cd miles/examples/swe-agent +python download_and_process_data.py --input SWE-Gym/SWE-Gym --output /root/swe_train.jsonl +``` + +## Running train +1. In environment docker, launch the agent server +```bash +cd Gym +source .venv/bin/activate +cd responses_api_agents/mini_swe_agent +./start_server.sh +``` + + +2. In miles docker, +(1) export `SWE_AGENT_GYM_URL` to be the port of the second server you started in Gym in environment docker, whose `server_type` is `responses_api_agents`. `swe_env` is the environment docker's name; replace it if you changed the name. +(minor TODO: modify the port selections to avoid setting this every time.) (2) launch the training. +```bash +export SWE_AGENT_GYM_URL="http://swe_env:" +bash examples/swe-agent/run-qwen3-4b-instruct.sh +``` + + +## Troubleshooting +1. The first time of every SWE environment can be slow, and may need to wait before generation, because each SWE-Gym task has a specific docker, and `docker pull` takes time. +2. Sometimes the environment may also be slow at evaluation. The timeout of evaluation is 10 minutes by default. If the server is stuck at `[EVAL] Running eval`, you may need to wait for it. 
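+3. If requests from the miles docker to the Gym server fail immediately, verify the two containers can actually reach each other over `swe-net`. A minimal connectivity check is sketched below (it only assumes `SWE_AGENT_GYM_URL` is exported as described above; it does not call any Gym endpoint):
+
+```python
+import os
+import socket
+from urllib.parse import urlparse
+
+# SWE_AGENT_GYM_URL should look like "http://swe_env:<port>"; parse host and port.
+url = urlparse(os.environ["SWE_AGENT_GYM_URL"])
+host, port = url.hostname, url.port or 80
+
+# A plain TCP connect is enough to tell whether the docker network wiring works.
+try:
+    with socket.create_connection((host, port), timeout=5):
+        print(f"OK: reached {host}:{port}")
+except OSError as exc:
+    print(f"FAILED to reach {host}:{port}: {exc}")
+```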
+ +## Metrics +``` +agent/turns_mean, agent/turns_sum - Turn counts +agent/tool_calls_mean, agent/tool_calls_sum - Tool call counts +agent/total_time_mean/max/min - Total time statistics +agent/model_query_time_sum_mean - Avg total model time per rollout +agent/env_execution_time_sum_mean - Avg total env time per rollout +agent/eval_time_mean - Avg evaluation time +agent/overhead_time_mean - Avg overhead time +agent/time_per_turn - Avg time per turn +agent/model_query_time_avg - Avg model query time per turn +agent/env_execution_time_avg - Avg env execution time per turn +agent/model_time_ratio, agent/env_time_ratio - Time ratios +``` diff --git a/examples/experimental/swe-agent/download_and_process_data.py b/examples/experimental/swe-agent/download_and_process_data.py new file mode 100755 index 000000000..3512bf3d4 --- /dev/null +++ b/examples/experimental/swe-agent/download_and_process_data.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +"""Download and process data to Miles format.""" + +import argparse +import json +import tempfile +from pathlib import Path +from datasets import load_dataset + + +def convert_to_miles_format(input_path: str, output_path: str, limit: int = None, split: str = "train"): + """Convert JSONL to Miles format. + + Args: + input_path: Path to input JSONL file + output_path: Path to output JSONL file in Miles format + limit: Optional limit on number of samples + split: Dataset split name (used in metadata) + """ + count = 0 + with open(input_path) as fin, open(output_path, "w") as fout: + for line in fin: + if limit and count >= limit: + break + + instance = json.loads(line) + + # Add subset and split to metadata for Gym API + metadata = dict(instance) + metadata["subset"] = "gym" + metadata["split"] = split + + miles_sample = { + "prompt": instance.get("problem_statement", ""), + "metadata": metadata, + } + + fout.write(json.dumps(miles_sample) + "\n") + count += 1 + + print(f"Converted {count} samples: {input_path} -> {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Download HuggingFace dataset and convert to Miles format") + parser.add_argument("--input", type=str, required=True, help="HuggingFace dataset path or local JSONL file") + parser.add_argument("--output", type=str, required=True, help="Output JSONL file path") + parser.add_argument( + "--split", type=str, default="train", help="Dataset split (default: train, only for HF datasets)" + ) + parser.add_argument("--limit", type=int, help="Limit number of samples") + + args = parser.parse_args() + + input_path = Path(args.input) + + if input_path.exists() and input_path.suffix == ".jsonl": + print(f"Processing local file: {args.input}") + convert_to_miles_format(args.input, args.output, args.limit, args.split) + else: + print(f"Loading HuggingFace dataset: {args.input} (split={args.split})") + ds = load_dataset(args.input, split=args.split) + + if args.limit: + ds = ds.select(range(min(args.limit, len(ds)))) + + tmp_path = None + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp: + tmp_path = tmp.name + + print(f"Downloading to temporary file: {tmp_path}") + ds.to_json(tmp_path) + + print(f"Converting to Miles format: {args.output}") + convert_to_miles_format(tmp_path, args.output, split=args.split) + finally: + if tmp_path and Path(tmp_path).exists(): + Path(tmp_path).unlink() + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/experimental/swe-agent/generate_with_swe_agent.py 
b/examples/experimental/swe-agent/generate_with_swe_agent.py new file mode 100644 index 000000000..b0dbbd612 --- /dev/null +++ b/examples/experimental/swe-agent/generate_with_swe_agent.py @@ -0,0 +1,242 @@ +import logging +import os +from argparse import Namespace +from collections.abc import Callable +from typing import Any + +from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.sglang_rollout import GenerateState, eval_rollout +from miles.utils.async_utils import run +from miles.utils.http_utils import post +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +def build_tokens_and_mask_from_messages( + messages: list[dict], + tokenizer, +) -> tuple[list[int], list[int], str, int]: + + if not messages or len(messages) < 2: + return [], [], "", 0 + + prompt_msgs = messages[:2] + response_msgs = messages[2:] + + prompt_tokens = [] + for msg in prompt_msgs: + content = msg.get("content", "") + if content: + prompt_tokens.extend(tokenizer(content, add_special_tokens=False)["input_ids"]) + + response_tokens = [] + loss_mask = [] + response_text_parts = [] + + for msg in response_msgs: + content = msg.get("content", "") + if not content: + continue + + tokens = tokenizer(content, add_special_tokens=False)["input_ids"] + token_len = len(tokens) + + response_tokens.extend(tokens) + response_text_parts.append(content) + + mask_val = 1 if msg.get("role") == "assistant" else 0 + loss_mask.extend([mask_val] * token_len) + + all_tokens = prompt_tokens + response_tokens + response_text = "".join(response_text_parts) + response_length = len(response_tokens) + + return all_tokens, loss_mask, response_text, response_length + + +async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: + """ + Custom generation function for SWE-Agent integration. + + Orchestrates the interaction with the external Gym environment: + 1. Sends prompt/metadata to Gym. + 2. Receives execution trace (messages) and rewards. + 3. Formats data for Miles training format. + + Note: Performs in-place modification of `sample` for memory efficiency. 
+ """ + # Prepare request for Gym /run endpoint + request = { + "responses_create_params": { + "input": [], + }, + "sampling_params": sampling_params, + **sample.metadata, + "sglang_url": f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1", + } + + gym_url = os.getenv("SWE_AGENT_GYM_URL", "http://localhost:11000") + response = await post(f"{gym_url}/run", request) + + exit_status = response.get("info", {}).get("exit_status", "") + logger.debug(f"exit_status: {exit_status}, reward: {response.get('reward', 0.0)}") + + messages = response.get("messages", []) + + if len(messages) >= 2: + sample.prompt = messages[:2] + + state = GenerateState(args) + tokens, loss_mask, response_text, response_length = build_tokens_and_mask_from_messages( + messages=messages, + tokenizer=state.tokenizer, + ) + + sample.rollout_log_probs = None # TODO + sample.tokens = tokens + sample.loss_mask = loss_mask + sample.response = response_text + sample.response_length = response_length + sample.metadata["reward"] = response.get("reward", 0.0) + sample.metadata["eval_report"] = response.get("metadata", {}) + sample.metadata["messages"] = messages + + agent_metrics = response.get("info", {}).get("agent_metrics", {}) + sample.metadata["agent_metrics"] = agent_metrics + + if exit_status == "Submitted": + sample.status = Sample.Status.COMPLETED + elif exit_status in ("RolloutTruncated", "LimitsExceeded", "CollapseContinued"): + sample.status = Sample.Status.TRUNCATED + else: + sample.status = Sample.Status.ABORTED + sample.reward = 0.0 + + return sample + + +async def reward_func(args, sample: Sample, **kwargs) -> float: + """Reward function - already computed in generate()""" + reward = sample.metadata.get("reward", 0.0) + return reward + + +def dynamic_filter(args, samples: list[Sample], **kwargs) -> DynamicFilterOutput: + """Filter out groups with any aborted samples from training""" + has_aborted = any(sample.status == Sample.Status.ABORTED for sample in samples) + if has_aborted: + return DynamicFilterOutput(keep=False, reason="group_has_aborted") + return DynamicFilterOutput(keep=True) + + +def aggregate_agent_metrics(samples: list[Sample]) -> dict: + """Aggregate agent metrics across samples for logging""" + metrics = {} + + all_metrics = [] + for sample in samples: + if hasattr(sample, "metadata") and sample.metadata: + agent_metrics = sample.metadata.get("agent_metrics", {}) + if agent_metrics: + all_metrics.append(agent_metrics) + + if not all_metrics: + return {} + + # Count metrics - mean and sum + for key in ["turns", "tool_calls"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}_mean"] = sum(values) / len(values) + metrics[f"agent/{key}_sum"] = sum(values) + + # Time sum metrics - mean across rollouts + for key in ["model_query_time_sum", "env_execution_time_sum", "eval_time", "agent_run_time"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}_mean"] = sum(values) / len(values) + + # Time avg metrics - mean of means + for key in ["time_per_turn", "model_query_time_avg", "env_execution_time_avg"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}"] = sum(values) / len(values) + + # Ratio metrics (all based on total_time which includes eval) + for key in ["model_time_ratio", "env_time_ratio", "eval_time_ratio"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}"] = sum(values) / len(values) + + # Total time stats + values = [m.get("total_time", 0) for m in 
all_metrics] + if values: + metrics["agent/total_time_mean"] = sum(values) / len(values) + metrics["agent/total_time_max"] = max(values) + metrics["agent/total_time_min"] = min(values) + + return metrics + + +async def generate_rollout_async( + args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + """ + Custom rollout function that wraps sglang_rollout.generate_rollout_async + and adds agent metrics aggregation. + """ + from miles.rollout.sglang_rollout import generate_rollout_async as base_generate_rollout_async + + rollout_output, aborted_samples = await base_generate_rollout_async(args, rollout_id, data_source) + + all_samples = [] + for group in rollout_output.samples: + if isinstance(group[0], list): + for sample_list in group: + all_samples.extend(sample_list) + else: + all_samples.extend(group) + + agent_metrics = aggregate_agent_metrics(all_samples) + + metrics = rollout_output.metrics or {} + metrics.update(agent_metrics) + + logger.info(f"Aggregated agent metrics for rollout {rollout_id}: {agent_metrics}") + + return RolloutFnTrainOutput(samples=rollout_output.samples, metrics=metrics), aborted_samples + + +def generate_rollout( + args: Namespace, rollout_id: int, data_buffer: Any, evaluation: bool = False +) -> RolloutFnTrainOutput | RolloutFnEvalOutput: + """An example to implement the generate_rollout function for an rule based rm rollout generation. + + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_buffer: the data buffer to store the generated samples + evaluation: bool, whether the rollout is for evaluation or not + + Returns: + list[list[Sample]]: a list of list of samples generated by the rollout + """ + output, aborted_samples = generate_abortable_samples( + args, rollout_id, data_buffer.get_samples, evaluation=evaluation + ) + data_buffer.add_samples(aborted_samples) + return output + + +def generate_abortable_samples( + args: Namespace, + rollout_id: int, + data_source: Callable[[int], list[list[Sample]]], + evaluation: bool = False, +) -> tuple[Any, list[list[Sample]]]: + assert args.rollout_global_dataset + if evaluation: + return run(eval_rollout(args, rollout_id)) + return run(generate_rollout_async(args, rollout_id, data_source)) diff --git a/examples/experimental/swe-agent/mini-swe-agent b/examples/experimental/swe-agent/mini-swe-agent new file mode 160000 index 000000000..8d74eee82 --- /dev/null +++ b/examples/experimental/swe-agent/mini-swe-agent @@ -0,0 +1 @@ +Subproject commit 8d74eee82036bc1c30f17c18b67c1e6984ad4f0b diff --git a/examples/experimental/swe-agent/nemo-gym b/examples/experimental/swe-agent/nemo-gym new file mode 160000 index 000000000..4fce289f9 --- /dev/null +++ b/examples/experimental/swe-agent/nemo-gym @@ -0,0 +1 @@ +Subproject commit 4fce289f9bbee420ebc9a7ac2f8884437d3a93ea diff --git a/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh b/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh new file mode 100755 index 000000000..d9c9dd953 --- /dev/null +++ b/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh @@ -0,0 +1,166 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; 
then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +export SWE_AGENT_GYM_URL="${SWE_AGENT_GYM_URL:-http://swe_env:11000}" + +source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B-Instruct-2507.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/qwen3-4B-Instruct-2507 + --ref-load /root/qwen3-4B-Instruct-2507_torch_dist + # --load /path/to/checkpoint/ + --save /root/qwen3-4B-Instruct-2507_miles/ + --save-interval 100 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 2048 +) + +ROLLOUT_ARGS=( + --prompt-data /root/swe_train.jsonl + --input-key prompt + --metadata-key metadata + --rollout-shuffle + --num-rollout 3000 + --rollout-batch-size 8 + --n-samples-per-prompt 8 + --rollout-temperature 0.8 + --rollout-max-response-len 8192 + + --global-batch-size 64 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 50 + # --eval-prompt-data /workspace/data/swe_gym_val.jsonl + # --eval-input-key prompt + # --eval-metadata-key metadata + # --n-samples-per-eval-prompt 1 + # --eval-max-response-len 4096 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.01 + --kl-loss-type low_var_kl + --entropy-coef 0.0 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=() +if [ -n "$WANDB_KEY" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project miles-swe-agent + --wandb-group swe-agent-qwen2.5-3b + --wandb-key ${WANDB_KEY} + ) +fi + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-generate-function-path generate_with_swe_agent.generate + --custom-rm-path generate_with_swe_agent.reward_func + --rollout-function-path generate_with_swe_agent.generate_rollout + --dynamic-sampling-filter-path generate_with_swe_agent.dynamic_filter +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +echo "Starting Ray cluster at ${MASTER_ADDR}..." +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 --port=8899 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}:/root/miles\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"SWE_AGENT_GYM_URL\": \"${SWE_AGENT_GYM_URL}\" + } +}" +# \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + +echo "Launching training..." 
+echo " SWE Agent URL: ${SWE_AGENT_GYM_URL}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} + +echo "Training completed!" diff --git a/examples/formal_math/single_round/README.md b/examples/formal_math/single_round/README.md index e66931566..706d7714a 100644 --- a/examples/formal_math/single_round/README.md +++ b/examples/formal_math/single_round/README.md @@ -1,23 +1,65 @@ -# Usage +# Single-Round Formal Math with Lean 4 and RL -For the minimal demo: +This directory contains an example of training a model to solve formal math problems using Lean 4. It leverages Reinforcement Learning (GRPO) with a "verifier-in-the-loop" approach, where generated proofs are verified for correctness using the [Kimina](https://github.com/project-numina/kimina-lean-server) verifier. -```shell -# install dependencies +## Overview + +- **Task**: Given a formal math statement in Lean 4, generate a valid proof. +- **Method**: Single-turn reinforcement learning (GRPO). The model generates a full proof (including thoughts/plans), and the reward is determined by whether the proof compiles and is valid. +- **Verifier**: Uses `kimina-lean-server` running in a Docker container to verify the generated Lean code. + +## Prerequisites + +### Docker Setup +You need Docker installed and a specific network for communication between the training process and the Kimina verifier: + +```bash +# Create a docker network for kimina and miles to communicate +docker network create formal_math +``` + +**Note**: The training script will launch a `kimina-lean-server` container. It requires mounting the host Docker socket (`/var/run/docker.sock`) so the script can manage sibling containers. Connect miles container to the same docker network. + +### Install Dependencies + +```bash apt update && apt install -y docker-cli pip install kimina-client polars +``` + +## Quick Start: Minimal Demo -# prepare data +This minimal demo (`run_minimal.py`) runs a self-contained training loop on a small dataset. + +### Prepare Data +Download and process the data (e.g., FineLeanCorpus, MiniF2F). + +```bash python examples/formal_math/single_round/prepare_data.py --output-name minimal_demo +``` -# prepare ray, model, test dataset, etc -# normally just use this script, but here we want to demonstrate run_minimal.py, thus skip ray-submit part +### Prepare Models & Environment +Use `run.py` to download the base model (e.g., Qwen3-8B) and set up the environment. We skip the actual training submission here (`MILES_SCRIPT_ENABLE_RAY_SUBMIT=0`) as we will use the minimal runner next. + +```bash MILES_SCRIPT_ENABLE_RAY_SUBMIT=0 python examples/formal_math/single_round/run.py +``` -# run +### Run Training +Launch the minimal training script. + +```bash python examples/formal_math/single_round/run_minimal.py ``` +## Advanced Usage + +For full-scale training or standard runs, use `run.py`. This script leverages `miles.utils.external_utils.command_utils` to handle cluster setup and execution. 
+ +```bash +python examples/formal_math/single_round/run.py +``` + The code also support more complicated cases, e.g.: * SFT + RL diff --git a/examples/formal_math/single_round/kimina_wrapper.py b/examples/formal_math/single_round/kimina_wrapper.py index 35a40f922..1f400fa15 100644 --- a/examples/formal_math/single_round/kimina_wrapper.py +++ b/examples/formal_math/single_round/kimina_wrapper.py @@ -56,30 +56,33 @@ def _create_actor_per_node(actor_cls) -> list: @ray.remote class _KiminaServerActor: def __init__(self): - self.addr = _get_current_node_host_ip() self.port = get_free_port() if _KILL_PREVIOUS_KIMINA_DOCKER: _docker_stop_all() - _docker_start(port=self.port) + self.docker_name = _docker_start(port=self.port) _wait_server_ready(base_url=self.get_api_url()) def get_api_url(self): - return f"http://{self.addr}:{self.port}" + return f"http://{self.docker_name}:8000" def _docker_start(port: int): - name = f"kimina_lean_server_auto_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{random.randint(0, 1000000)}" + docker_name = ( + f"kimina_lean_server_auto_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{random.randint(0, 1000000)}" + ) exec_command( "docker run " "-d " - f"--name {name} " + f"--name {docker_name} " "--restart unless-stopped " + "--network formal_math " # "--env-file .env " # do not use env yet f"-p {port}:8000 " f"projectnumina/kimina-lean-server:2.0.0" ) + return docker_name def _wait_server_ready(base_url: str): @@ -101,13 +104,3 @@ def _docker_stop_all(): '[ -n "$ids" ] && docker stop $ids && docker rm $ids; ' "true" ) - - -def _get_current_node_host_ip(): - # when RL container uses network=host - return "127.0.0.1" - - # when RL container does not use network=host - # https://stackoverflow.com/questions/22944631 - # out = exec_command("ip route show default | awk '/default/ {print $3}'", capture_output=True) - # return out.strip() diff --git a/examples/formal_math/single_round/run.py b/examples/formal_math/single_round/run.py index 8cbb9d738..25f9ad07f 100644 --- a/examples/formal_math/single_round/run.py +++ b/examples/formal_math/single_round/run.py @@ -23,7 +23,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") if arg_ref_load is None: U.convert_checkpoint( model_name=MODEL_NAME, diff --git a/examples/formal_math/single_round/run_minimal.py b/examples/formal_math/single_round/run_minimal.py index 469d6b833..69d092884 100644 --- a/examples/formal_math/single_round/run_minimal.py +++ b/examples/formal_math/single_round/run_minimal.py @@ -96,10 +96,14 @@ ) wandb_args = ( - "--use-wandb " - "--wandb-project miles-formal-math-run-minimal " - "--wandb-group demo " - "--wandb-key ${WANDB_API_KEY} " + ( + "--use-wandb " + "--wandb-project miles-formal-math-run-minimal " + "--wandb-group demo " + f"--wandb-key '{wandb_api_key}' " + ) + if (wandb_api_key := os.environ.get("WANDB_API_KEY")) + else "" ) train_args = ( diff --git a/examples/formal_math/single_round/run_sft.py b/examples/formal_math/single_round/run_sft.py index d294b133a..f24f79e3a 100644 --- a/examples/formal_math/single_round/run_sft.py +++ b/examples/formal_math/single_round/run_sft.py @@ -13,7 +13,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + 
U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/examples/fully_async/README.md b/examples/fully_async/README.md index 7c7bdf343..53f08b3ca 100644 --- a/examples/fully_async/README.md +++ b/examples/fully_async/README.md @@ -1,15 +1,15 @@ -## Fully Asynchronous Rollout Example +# Fully Asynchronous Rollout Example This example shows a simple way to make rollout generation **fully asynchronous**: a single global worker is created once and then keeps running in the background, continuously pulling prompts and launching generation tasks. Training only needs to fetch already finished results. This removes the perโ€‘step wait that happens in the normal synchronous style. -### Files +## Files * `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry. * `run-qwen3-4b-fully_async.sh`: example launch script with Qwen3โ€‘4B. -### Prerequisite +## Prerequisite First set up model & environment following the Qwen3-4B example. -### Quick Start +## Quick Start ```bash cd miles bash examples/fully_async/run-qwen3-4b-fully_async.sh @@ -20,18 +20,18 @@ Creating new global async worker... Continuous async rollout worker started ``` -### How It Works (Very Short) +## How It Works (Very Short) * First call: create `AsyncRolloutWorker` (thread + asyncio loop). * Loop keeps up to `--rollout-batch-size` tasks in flight using `generate_and_rm_group`. * Completed groups are pushed into a queue; caller drains until it has enough samples. * Worker is stopped automatically at process exit. -### Limitations +## Limitations * No evaluation mode. * Ordering is best effort (sorted at the end by index). * Minimal error handling. -### Config Differences (2 Key Points) +## Config Differences (2 Key Points) To enable the fully async pattern there are only two changes compared to a normal run: 1. Use the async training driver: `train_async.py` (not `train.py`). diff --git a/examples/geo3k_vlm/README.md b/examples/geo3k_vlm/README.md index 1946999dd..751faec02 100644 --- a/examples/geo3k_vlm/README.md +++ b/examples/geo3k_vlm/README.md @@ -1,19 +1,89 @@ -# FSDP + VLM Single-Turn RL +# VLM Single-Turn RL (FSDP & Megatron) -Training VLMs with FSDP on single-turn reasoning task using GRPO on the [GEO3K dataset](https://huggingface.co/datasets/hiyouga/geometry3k). We used processed version [here](https://huggingface.co/datasets/chenhegu/geo3k_imgurl). +Training VLMs with FSDP or Megatron on single-turn reasoning task using GRPO on the [GEO3K dataset](https://huggingface.co/datasets/hiyouga/geometry3k). We used processed version [here](https://huggingface.co/datasets/chenhegu/geo3k_imgurl). + +Note: Please make sure the cudnn version in the environment is 9.16.0.29 to prevent severe performance regression in conv3d in torch 2.9 mentioned in https://github.com/pytorch/pytorch/issues/168167. Otherwise, you can reinstall cudnn with: +```bash +pip install nvidia-cudnn-cu12==9.16.0.29 +```

-![Reward Plot](rewards.png)
+![FSDP vs Megatron Reward Plot](fsdp_vs_megatron.png)

+## Data Preparation (For SFT Training) + +The [geo3k_imgurl](https://huggingface.co/datasets/chenhegu/geo3k_imgurl) dataset contains: +- `problem`: The math problem text (string) +- `answer`: The answer (string, e.g., "270") +- `images`: Image data (list) + +For SFT training, we need to format the `answer` field for `\boxed{}` format and the messages. You can use the following script to format the answer field: + +```python +from datasets import load_dataset +import pandas as pd + +ds = load_dataset("chenhegu/geo3k_imgurl", split="train") + +def format_answer(answer: str) -> str: + """Format answer to include \\boxed{} format.""" + return f"Answer: \\boxed{{{answer}}}" + +def process_sample(sample): + formatted_answer = f"Answer: \\boxed{{{sample['answer']}}}" + + sample["messages"] = [ + {"role": "user", "content": sample["problem"]}, + {"role": "assistant", "content": formatted_answer} + ] + return sample + +ds = ds.map(process_sample) +ds.to_parquet("/root/datasets/geo3k_imgurl/train_formatted.parquet") +``` + ## Reproduce ```bash export WANDB_API_KEY=your_wandb_api_key -MILES_SCRIPT_MODEL_NAME=Qwen3-VL-2B-Instruct MILES_SCRIPT_NUM_GPUS=8 python examples/geo3k_vlm/run_geo3k_vlm.py 2>&1 | tee run_simple.log +# Megatron backend (default -> Qwen3-VL-8B-Instruct + Megatron) +./examples/geo3k_vlm/run_geo3k_vlm.sh + +# FSDP backend +MILES_SCRIPT_TRAIN_BACKEND=fsdp ./examples/geo3k_vlm/run_geo3k_vlm.sh + +# With different model +MILES_SCRIPT_MODEL_NAME=Qwen3-VL-4B-Instruct ./examples/geo3k_vlm/run_geo3k_vlm.sh + +# SFT +./examples/geo_3k_vlm/run_geo3k_vlm_sft.sh ``` +### Configuration + +| Environment Variable | Default | Description | +|---------------------|---------|-------------| +| `MILES_SCRIPT_TRAIN_BACKEND` | `megatron` | Training backend (`megatron` or `fsdp`) | +| `MILES_SCRIPT_MODEL_NAME` | `Qwen3-VL-8B-Instruct` | Model name | +| `MILES_SCRIPT_DATASET_NAME` | `chenhegu/geo3k_imgurl` | HuggingFace dataset name | +| `MILES_SCRIPT_NUM_GPUS` | `8` | Number of GPUs | +| `MILES_SCRIPT_EXTERNAL_RAY` | `0` | Use external Ray cluster (`1` to enable) | + +### Supported Models + +- `Qwen3-VL-2B-Instruct` +- `Qwen3-VL-4B-Instruct` +- `Qwen3-VL-8B-Instruct` +- `Qwen3-VL-30B-A3B-Instruct` +- `Qwen3-VL-235B-A22B-Instruct` +- `Qwen3-VL-2B-Thinking` +- `Qwen3-VL-4B-Thinking` +- `Qwen3-VL-8B-Thinking` +- `Qwen3-VL-30B-A3B-Thinking` +- `Qwen3-VL-235B-A22B-Thinking` + ## Notes ### Reward Model Configuration @@ -32,4 +102,4 @@ Our initial geo3k-specific verifier produced "format scores" (**0 and 0.9**) ins We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `miles/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts. 
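
For intuition, the snippet below (not the repository code) shows why the `float16` round trip helps: tiny float32 noise on otherwise-identical rewards is truncated away, so the values compare equal again.

```python
import torch

# Two rewards that are conceptually the same 0.9 score, but one carries
# accumulated float32 noise from upstream arithmetic.
rewards = torch.tensor([0.9, 0.90000004], dtype=torch.float32)
print(torch.unique(rewards))   # two distinct values

# Round-tripping through float16 truncates the noise.
truncated = rewards.to(torch.float16).to(torch.float32)
print(torch.unique(truncated))  # a single value: both collapse to the same float16
```
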
## B200 -Blackwell currently does not support fa3, we need to use `--sglang-mm-attention-backend sdpa` and `--attn-implementation flash_attention_2` \ No newline at end of file +Blackwell currently does not support fa3, we need to use `--sglang-mm-attention-backend sdpa` and `--attn-implementation flash_attention_2` diff --git a/examples/geo3k_vlm/fsdp_vs_megatron.png b/examples/geo3k_vlm/fsdp_vs_megatron.png new file mode 100644 index 000000000..5c32e414b Binary files /dev/null and b/examples/geo3k_vlm/fsdp_vs_megatron.png differ diff --git a/examples/geo3k_vlm/rewards.png b/examples/geo3k_vlm/rewards.png deleted file mode 100644 index 4b1c6c0ce..000000000 Binary files a/examples/geo3k_vlm/rewards.png and /dev/null differ diff --git a/examples/geo3k_vlm/run_geo3k_vlm.py b/examples/geo3k_vlm/run_geo3k_vlm.py deleted file mode 100644 index 0106d2beb..000000000 --- a/examples/geo3k_vlm/run_geo3k_vlm.py +++ /dev/null @@ -1,132 +0,0 @@ -import os - -import miles.utils.misc as U -from miles.utils.external_utils.command_utils import execute_train, get_default_wandb_args - -MODEL_NAME = os.environ.get("MILES_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct") -assert MODEL_NAME in {"Qwen2.5-VL-3B-Instruct", "Qwen3-VL-2B-Instruct", "Qwen3-VL-4B-Instruct", "Qwen3-VL-8B-Instruct"} - -NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "1")) -EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) - - -def prepare(): - U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") - dataset_name = "chenhegu/geo3k_imgurl" - _, partial_name = dataset_name.split("/") - U.exec_command(f"hf download --repo-type dataset {dataset_name} --local-dir /root/datasets/{partial_name}") - - -def execute(): - - ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " - - rollout_args = ( - "--prompt-data /root/datasets/geo3k_imgurl/train.parquet " - "--input-key problem " - "--label-key answer " - '--multimodal-keys \'{"image": "images"}\' ' - "--apply-chat-template " - "--rollout-shuffle " - "--rm-type math " - "--num-rollout 3000 " - "--rollout-batch-size 64 " - "--n-samples-per-prompt 8 " - "--rollout-max-response-len 4096 " - "--rollout-temperature 1 " - "--global-batch-size 512 " - ) - - eval_args = ( - "--eval-interval 20 " - "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " - "--n-samples-per-eval-prompt 1 " - "--eval-max-response-len 4096 " - "--eval-top-k 1 " - ) - - grpo_args = ( - "--advantage-estimator grpo " - # "--use-kl-loss " - "--kl-loss-coef 0.00 " - "--kl-loss-type low_var_kl " - "--kl-coef 0.00 " - "--entropy-coef 0.00 " - "--eps-clip 0.2 " - "--eps-clip-high 0.28 " - ) - - optimizer_args = ( - "--optimizer adam " - "--lr 1e-6 " - "--lr-decay-style constant " - "--weight-decay 0.1 " - "--adam-beta1 0.9 " - "--adam-beta2 0.98 " - ) - - sglang_args = ( - "--rollout-num-gpus-per-engine 1 " - "--sglang-mem-fraction-static 0.6 " - f"--sglang-cuda-graph-bs {' '.join(map(str, [1, 2, 4, 8] + list(range(16, 257, 8))))} " - ) - - fsdp_args = ( - # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default) - # "--fsdp-full-params " # Uncomment this line to enable full params mode - # Set the bucket size for weight update - "--update-weight-buffer-size 536870912 " # 512MB - "--train-backend fsdp " - "--gradient-checkpointing " - "--sglang-attention-backend fa3 " - "--attn-implementation flash_attention_3 " - ) - - misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node 
{NUM_GPUS} " "--colocate " - - # misc_args += ( - # "--use-dynamic-batch-size " - # # TODO pick a good value - # "--max-tokens-per-gpu 2048 " - # ) - - # true_on_policy_args = ( - # "--sglang-enable-deterministic-inference " - # "--sglang-rl-on-policy-target fsdp " - # "--deterministic-mode " - # "--true-on-policy-mode " - # ) - # true_on_policy_envs = { - # # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang - # # "NCCL_ALGO": "Ring", - # "NCCL_ALGO": "allreduce:tree", - # "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", - # "CUBLAS_WORKSPACE_CONFIG": ":4096:8", - # } - - train_args = ( - f"{ckpt_args} " - f"{rollout_args} " - f"{optimizer_args} " - f"{grpo_args} " - f"{sglang_args} " - f"{fsdp_args} " - f"{eval_args} " - f"{misc_args} " - f"{get_default_wandb_args(__file__)} " - # f"{true_on_policy_args} " - ) - - # Submit Ray job - execute_train( - train_args=train_args, - num_gpus_per_node=NUM_GPUS, - megatron_model_type=None, - extra_env_vars={}, - ) - - -if __name__ == "__main__": - prepare() - execute() diff --git a/examples/geo3k_vlm/run_geo3k_vlm.sh b/examples/geo3k_vlm/run_geo3k_vlm.sh new file mode 100644 index 000000000..051efc285 --- /dev/null +++ b/examples/geo3k_vlm/run_geo3k_vlm.sh @@ -0,0 +1,225 @@ +#!/bin/bash + +# Qwen3 VL RL training on geo3k dataset +# Supports both megatron and fsdp training backends +# Usage: +# MILES_SCRIPT_TRAIN_BACKEND=fsdp ./run_geo3k_vlm.sh +# MILES_SCRIPT_MODEL_NAME=Qwen3-VL-2B-Instruct ./run_geo3k_vlm.sh + +# Configuration +TRAIN_BACKEND=${MILES_SCRIPT_TRAIN_BACKEND:-"megatron"} +MODEL_NAME=${MILES_SCRIPT_MODEL_NAME:-"Qwen3-VL-8B-Instruct"} +DATASET_NAME=${MILES_SCRIPT_DATASET_NAME:-"chenhegu/geo3k_imgurl"} +NUM_GPUS=${MILES_SCRIPT_NUM_GPUS:-8} +DATASET_LOCAL_NAME=$(basename "$DATASET_NAME") + +# Validate MODEL_NAME +VALID_MODELS=" + Qwen3-VL-2B-Instruct + Qwen3-VL-4B-Instruct + Qwen3-VL-8B-Instruct + Qwen3-VL-30B-A3B-Instruct + Qwen3-VL-235B-A22B-Instruct + Qwen3-VL-2B-Thinking + Qwen3-VL-4B-Thinking + Qwen3-VL-8B-Thinking + Qwen3-VL-30B-A3B-Thinking + Qwen3-VL-235B-A22B-Thinking +" +if ! echo "$VALID_MODELS" | grep -qw "$MODEL_NAME"; then + echo "Error: MODEL_NAME must be one of: $VALID_MODELS" + exit 1 +fi + +MODEL_NAME_LOWER=$(echo "$MODEL_NAME" | tr '[:upper:]' '[:lower:]') + +# External Ray flag +if [ -z "$MILES_SCRIPT_EXTERNAL_RAY" ] || [ "$MILES_SCRIPT_EXTERNAL_RAY" = "0" ]; then + USE_EXTERNAL_RAY=0 +else + USE_EXTERNAL_RAY=1 +fi + +# Cleanup +pkill -9 sglang +sleep 3 +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + ray stop --force + pkill -9 ray +fi +pkill -9 miles +sleep 3 +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + pkill -9 ray +fi +pkill -9 miles +pkill -9 redis + +set -ex + +export PYTHONBUFFERED=16 + +# Detect NVLink +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +# Download model and dataset +mkdir -p /root/models /root/datasets +if [ ! -d "/root/models/${MODEL_NAME}" ]; then + hf download Qwen/${MODEL_NAME} --local-dir /root/models/${MODEL_NAME} +fi +if [ ! 
-d "/root/datasets/${DATASET_LOCAL_NAME}" ]; then + hf download --repo-type dataset ${DATASET_NAME} --local-dir /root/datasets/${DATASET_LOCAL_NAME} +fi + +# Common args +CKPT_ARGS=( + --hf-checkpoint /root/models/${MODEL_NAME} +) + +ROLLOUT_ARGS=( + --prompt-data /root/datasets/${DATASET_LOCAL_NAME}/train.parquet + --input-key problem + --label-key answer + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 3000 + --rollout-batch-size 64 + --n-samples-per-prompt 8 + --rollout-max-response-len 4096 + --rollout-temperature 0.8 + --global-batch-size 512 +) + +# required for vlm datasets +MULTIMODAL_KEYS='{"image": "images"}' + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data ${DATASET_LOCAL_NAME} /root/datasets/${DATASET_LOCAL_NAME}/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 4096 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.6 + --sglang-cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 +) + +# Wandb args (only if WANDB_API_KEY is set) +if [ -n "$WANDB_API_KEY" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project miles-geo3k-vlm + --wandb-group ${MODEL_NAME_LOWER}-${TRAIN_BACKEND} + --wandb-key ${WANDB_API_KEY} + --disable-wandb-random-suffix + ) +else + WANDB_ARGS=() +fi + +MISC_ARGS=( + --colocate +) + +# Backend-specific args +if [ "$TRAIN_BACKEND" = "fsdp" ]; then + BACKEND_ARGS=( + --train-backend fsdp + --gradient-checkpointing + --sglang-attention-backend fa3 + --attn-implementation flash_attention_3 + --update-weight-buffer-size 536870912 + ) + MODEL_ARGS=() +else + # megatron backend (default) + BACKEND_ARGS=( + --train-backend megatron + --load /root/models/${MODEL_NAME} + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + --megatron-to-hf-mode bridge + ) + + # get MODEL_ARGS from scripts/models for megatron backend + MILES_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." 
&>/dev/null && pwd)" + MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g') + # VL models require rotary-base 5000000 + MODEL_ARGS_ROTARY_BASE=5000000 source "${MILES_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh" + +fi + +# Start Ray if not using external Ray +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + export no_proxy="127.0.0.1,${MASTER_ADDR}" + ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +fi + +# Build runtime env +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ + --multimodal-keys "${MULTIMODAL_KEYS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${BACKEND_ARGS[@]} \ + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/geo3k_vlm/run_geo3k_vlm_sft.sh b/examples/geo3k_vlm/run_geo3k_vlm_sft.sh new file mode 100644 index 000000000..764a7df39 --- /dev/null +++ b/examples/geo3k_vlm/run_geo3k_vlm_sft.sh @@ -0,0 +1,186 @@ +TRAIN_BACKEND=${MILES_SCRIPT_TRAIN_BACKEND:-"megatron"} +MODEL_NAME=${MILES_SCRIPT_MODEL_NAME:-"Qwen3-VL-8B-Instruct"} +DATASET_NAME=${MILES_SCRIPT_DATASET_NAME:-"chenhegu/geo3k_imgurl"} +NUM_GPUS=${MILES_SCRIPT_NUM_GPUS:-8} +DATASET_LOCAL_NAME=$(basename "$DATASET_NAME") + +# Validate MODEL_NAME +VALID_MODELS=" + Qwen3-VL-2B-Instruct + Qwen3-VL-4B-Instruct + Qwen3-VL-8B-Instruct + Qwen3-VL-2B-Thinking + Qwen3-VL-4B-Thinking + Qwen3-VL-8B-Thinking + Qwen3-VL-30B-A3B-Instruct + Qwen3-VL-235B-A22B-Instruct + Qwen3-VL-30B-A3B-Thinking + Qwen3-VL-235B-A22B-Thinking +" +if ! echo "$VALID_MODELS" | grep -qw "$MODEL_NAME"; then + echo "Error: MODEL_NAME must be one of: $VALID_MODELS" + exit 1 +fi + +MODEL_NAME_LOWER=$(echo "$MODEL_NAME" | tr '[:upper:]' '[:lower:]') + +# External Ray flag +if [ -z "$MILES_SCRIPT_EXTERNAL_RAY" ] || [ "$MILES_SCRIPT_EXTERNAL_RAY" = "0" ]; then + USE_EXTERNAL_RAY=0 +else + USE_EXTERNAL_RAY=1 +fi + +# Cleanup +pkill -9 sglang +sleep 3 +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + ray stop --force + pkill -9 ray +fi +pkill -9 miles +sleep 3 +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + pkill -9 ray +fi +pkill -9 miles +pkill -9 redis + +set -ex + +export PYTHONBUFFERED=16 + +# Detect NVLink +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +# Download model and dataset +mkdir -p /root/models /root/datasets +if [ ! -d "/root/models/${MODEL_NAME}" ]; then + hf download Qwen/${MODEL_NAME} --local-dir /root/models/${MODEL_NAME} +fi +if [ ! 
-d "/root/datasets/${DATASET_LOCAL_NAME}" ]; then + hf download --repo-type dataset ${DATASET_NAME} --local-dir /root/datasets/${DATASET_LOCAL_NAME} +fi + +# Common args +CKPT_ARGS=( + --hf-checkpoint /root/models/${MODEL_NAME} + --load /root/models/${MODEL_NAME} +) + +SFT_ARGS=( + --rollout-function-path miles.rollout.sft_rollout.generate_rollout + --prompt-data /root/datasets/${DATASET_LOCAL_NAME}/train_formatted.parquet + --input-key messages + --apply-chat-template + --rollout-shuffle + --num-epoch 3000 + --rollout-batch-size 128 + --global-batch-size 128 + + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --debug-train-only +) + +# required for vlm datasets +MULTIMODAL_KEYS='{"image": "images"}' + + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style cosine + --min-lr 1e-6 + --lr-warmup-fraction 0.1 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 +) + +if [ -n "$WANDB_API_KEY" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project miles-geo3k-vlm-sft + --wandb-group ${MODEL_NAME_LOWER}-${TRAIN_BACKEND} + --wandb-key ${WANDB_API_KEY} + --disable-wandb-random-suffix + ) +else + WANDB_ARGS=() +fi + +# Backend-specific args +if [ "$TRAIN_BACKEND" = "fsdp" ]; then + BACKEND_ARGS=( + --train-backend fsdp + --gradient-checkpointing + --attn-implementation flash_attention_3 + --update-weight-buffer-size 536870912 + ) +else + # megatron backend (default) + BACKEND_ARGS=( + --train-backend megatron + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + --megatron-to-hf-mode bridge + ) + + # get MODEL_ARGS from scripts/models for megatron backend + MILES_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." 
&>/dev/null && pwd)" + MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g') + # VL models require rotary-base 5000000 + MODEL_ARGS_ROTARY_BASE=5000000 source "${MILES_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh" +fi + +# Start Ray if not using external Ray +if [ "$USE_EXTERNAL_RAY" = "0" ]; then + export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + export no_proxy="127.0.0.1,${MASTER_ADDR}" + ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +fi + +# Build runtime env +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train_async.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ + --multimodal-keys "${MULTIMODAL_KEYS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${SFT_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${BACKEND_ARGS[@]} diff --git a/examples/geo3k_vlm_multi_turn/README.md b/examples/geo3k_vlm_multi_turn/README.md new file mode 100644 index 000000000..dfbed9f67 --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/README.md @@ -0,0 +1,47 @@ +# VLM Multi-Turn (geo3k dataset) +Training VLM on [geo3k dataset](https://huggingface.co/datasets/hiyouga/geometry3k) with multi-turn reasoning with interactive environment feedback, using GRPO. For the dataset, we used the [processed version](https://huggingface.co/datasets/VeraIsHere/geo3k_imgurl_processed). + +Note: Please make sure the cudnn version in the environment is 9.16.0.29 to prevent severe performance regression in conv3d in torch 2.9 mentioned in https://github.com/pytorch/pytorch/issues/168167. Otherwise, you can reinstall cudnn with: +```bash +pip install nvidia-cudnn-cu12==9.16.0.29 +``` + +The multi-turn rollout is implemented through a [custom generate function](rollout.py#L309), overriding the original generate function. + +In terms of the environment interaction, this example initializes a [custom interactive environment](env_geo3k.py) with the APIs below. +
+Environment API (geo3k)
+
+- `build_env(sample: Sample | None = None, args: Any | None = None, **_) -> Geo3kEnv`: constructs the env.
+- `reset() -> tuple[dict, dict]`: clears internal state.
+- `step(response_text: str) -> tuple[dict, bool, dict]`: parses the actor's response text and updates the state. Returns the new observation, a flag marking whether the task is done, and step info.
+- `format_observation(observation: dict) -> dict`: converts an env observation into a chat message.
+
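+A rough sketch of how a rollout loop can drive this contract each turn is shown below. It is not the actual implementation in `rollout.py`; `generate_text` stands in for the SGLang generation call, `build_env` is assumed to be exposed by `env_geo3k.py` as listed above, and `sample.prompt` is assumed to hold the initial chat messages.
+
+```python
+from examples.geo3k_vlm_multi_turn.env_geo3k import build_env
+
+
+def run_episode(sample, args, generate_text, max_turns=4):
+    env = build_env(sample=sample, args=args)
+    observation, _reset_info = env.reset()
+    messages = list(sample.prompt)  # initial chat messages from the dataset
+    for _ in range(max_turns):
+        response_text = generate_text(messages)  # one assistant turn
+        messages.append({"role": "assistant", "content": response_text})
+        observation, done, _step_info = env.step(response_text)
+        if done:
+            break
+        # Environment feedback becomes the next input message.
+        messages.append(env.format_observation(observation))
+    env.close()
+    return messages
+```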

+ + +The reward model is the default math RM. + +![VLM multi-turn geo3k reward](geo3k_vlm_multi_turn_reward.png) +![Rollout megatron](rollout_experiment_result_megatron.png) + +## Reproduce +```bash +# 1) Set environment variable +export WANDB_API_KEY=... +export MILES_SCRIPT_MODEL_NAME=Qwen3-VL-2B-Instruct +export MILES_SCRIPT_NUM_GPUS=4 +export MILES_SCRIPT_TRAIN_BACKEND=fsdp + +# 2) Download the dataset +hf download --repo-type dataset VeraIsHere/geo3k_imgurl_processed --local-dir /root/datasets/geo3k_imgurl_processed + +# 3) Run the script: +cd /root/miles +python examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py +``` + +## What each file does +- `examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py`: downloads model, sets training/rollout args, and launches the run. +- `examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml`: specifies `max_turns` and `rollout_interaction_env_path` for the multi-turn rollout. +- `examples/geo3k_vlm_multi_turn/rollout.py`: custom multi-turn rollout that calls SGLang for token generation, builds loss masks/log_probs, enforces max_turns, and early-stops on max_new_tokens. +- `examples/geo3k_vlm_multi_turn/env_geo3k.py`: geo3k tool-calling env that parses {...}, scores math answers, and returns tool feedback per turn. diff --git a/examples/geo3k_vlm_multi_turn/__init__.py b/examples/geo3k_vlm_multi_turn/__init__.py new file mode 100644 index 000000000..526e3ae7c --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/__init__.py @@ -0,0 +1 @@ +# Multi-turn VLM Sokoban example package diff --git a/examples/geo3k_vlm_multi_turn/base_env.py b/examples/geo3k_vlm_multi_turn/base_env.py new file mode 100644 index 000000000..05d9632b3 --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/base_env.py @@ -0,0 +1,25 @@ +class BaseInteractionEnv: + """ + Base class that defines the explicit contract for interaction environments. + """ + + def reset(self): + raise NotImplementedError + + def step(self, response_text: str): + raise NotImplementedError + + def close(self): + pass + + def format_observation(self, observation: dict) -> dict: + observation = observation or {} + content = [] + multimodal = observation.get("multi_modal_data") or {} + + for _, images in multimodal.items(): + for image in images: + content.append({"type": "image", "image": image}) + + content.append({"type": "text", "text": observation.get("obs_str", "")}) + return {"role": "user", "content": content} diff --git a/examples/geo3k_vlm_multi_turn/env_geo3k.py b/examples/geo3k_vlm_multi_turn/env_geo3k.py new file mode 100644 index 000000000..be089df9b --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/env_geo3k.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import json +import logging +import re +from copy import deepcopy +from typing import Any + +try: + import orjson # type: ignore +except Exception: # pragma: no cover - optional dependency + orjson = None +from examples.geo3k_vlm_multi_turn.base_env import BaseInteractionEnv + +from miles.rollout.rm_hub import grade_answer_verl +from miles.rollout.rm_hub.math_utils import extract_answer as extract_boxed_answer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + +# Matches the JSON payload emitted between ... tags. +TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) +# Accept either name; verl uses `calc_geo3k_reward` while the instruction refers to `calc_score`. 
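+# Example payload (illustrative answer value) that _extract_tool_call should recover:
+#   {"name": "calc_score", "arguments": {"answer": "42"}}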
+SUPPORTED_TOOL_NAMES = {"calc_score", "calc_geo3k_reward"} + + +class Geo3kEnv(BaseInteractionEnv): + """ + Minimal interaction environment for multi-turn geo3k with a scoring tool. + + The model is expected to emit a {...} payload that includes + an `answer` argument. We run the math reward checker against the ground truth and + return the score as the next observation. The episode ends immediately after each + step; responses are provided but no further turns are taken. + """ + + def __init__(self, *, ground_truth: str | None = None, max_turns: int | None = None): + self.ground_truth = str(ground_truth) if ground_truth is not None else None + self.tool_calls: list[dict[str, Any]] = [] + self.last_tool_score: float | None = None + self.turn = 0 + self.max_turns = max_turns + + def reset(self): + self.tool_calls.clear() + self.last_tool_score = None + self.turn = 0 + # No initial observation is needed; the question lives in the prompt. + observation: dict[str, Any] = {} + reset_info = {"ground_truth_available": self.ground_truth is not None} + return observation, reset_info + + def close(self): + """No resources to release.""" + return + + def _extract_tool_call(self, text: str) -> dict[str, Any] | None: + """ + Parse the latest tool call payload from the assistant response. + Supports the {...} convention used in the + SGLang multi-turn templates. Tool tags are mandatory. + """ + matches = list(TOOL_CALL_RE.finditer(text)) + raw_json = None + if matches: + raw_json = matches[-1].group(1).strip() + + if raw_json is None: + return None + + payload = self._parse_tool_payload(raw_json) + if payload is None: + return None + + name = payload.get("name") or payload.get("function", {}).get("name") + arguments = payload.get("arguments") or payload.get("function", {}).get("arguments") or {} + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + logger.warning("Tool call arguments are not valid JSON; rejecting tool call.") + return None + + if not name: + return None + return {"name": name, "arguments": arguments} + + def _score_answer(self, answer: str) -> float: + """ + Use the same logic as the single-turn math reward model. + We accept either boxed or raw numeric strings by retrying with a boxed wrapper. + """ + if not self.ground_truth: + return 0.0 + + answer = answer.strip() + candidates = [answer] + if "\\boxed" not in answer: + candidates.append(f"\\boxed{{{answer}}}") + + for candidate in candidates: + try: + if grade_answer_verl(candidate, self.ground_truth): + return 1.0 + except Exception as exc: # pragma: no cover - defensive + logger.debug("grade_answer_verl failed on %s: %s", candidate, exc) + continue + return 0.0 + + def _extract_answer_from_text(self, text: str) -> str | None: + """ + Prefer a concise answer by pulling the last \\boxed{} chunk; fall back to the last + non-empty line (capped) to avoid echoing the whole response body. + """ + boxed = extract_boxed_answer(text) + if boxed: + return str(boxed).strip() + for line in reversed(text.splitlines()): + cleaned = line.strip() + if cleaned: + return cleaned[:512] + trimmed = text.strip() + return trimmed[:512] if trimmed else None + + def _extract_balanced_json(self, text: str, start: int) -> str | None: + """ + Best-effort balanced brace extraction starting at `start` (index of an opening '{'). + Keeps string-awareness to avoid terminating inside quoted braces. 
+ """ + depth = 0 + in_string = False + escaped = False + for idx in range(start, len(text)): + ch = text[idx] + if ch == "\\" and not escaped: + escaped = True + continue + if ch == '"' and not escaped: + in_string = not in_string + if not in_string: + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return text[start : idx + 1] + escaped = False + return None + + def _build_tool_feedback(self, score: float, parsed_answer: str) -> str: + """ + Provide concise feedback for the model to continue reasoning. + """ + turn_idx = self.turn - 1 # zero-based + # Send the final reminder one turn before the true last turn so the model sees it in time. + last_warning_turn = None + if self.max_turns is not None: + if self.max_turns >= 2: + last_warning_turn = self.max_turns - 2 + else: + last_warning_turn = self.max_turns - 1 + is_final_turn = last_warning_turn is not None and turn_idx >= last_warning_turn + + if score == 1.0: + return ( + f"calc_score result: {score}. Parsed answer '{parsed_answer}' matches the reference. " + "You can now stop reasoning and provide the final solution in \\boxed{}." + ) + if score == 0.0: + if is_final_turn: + return ( + f"calc_score result: {score}. Parsed answer '{parsed_answer}' does not match the reference. " + "Your answer is wrong. You may need to reason in a different way. Don't repeat your answer unless necessary. " + "Since you only have one chance to answer, don't call tool again. You should provide your final answer in the form below Answer: \\boxed{$Answer} where $Answer is your fiinal answer to this problem." + ) + return ( + f"calc_score result: {score}. Parsed answer '{parsed_answer}' does not match the reference. " + "Your answer is wrong. You may need to reason in a different way. Don't repeat your answer unless necessary." + ) + + # Called during rollout after receiving a model response + def step(self, response_text: str): + self.turn += 1 + is_final_turn = self.max_turns is not None and self.turn >= self.max_turns + tool_call = self._extract_tool_call(response_text) + info: dict[str, Any] = {"tool_call": deepcopy(tool_call)} + + if not tool_call: + info["tool_executed"] = False + obs = { + "obs_str": "No tool call detected; ending the episode.", + "role": "tool", + } + return obs, True, info + + name = (tool_call.get("name") or "").strip() + arguments = tool_call.get("arguments") or {} + if name not in SUPPORTED_TOOL_NAMES: + obs = { + "obs_str": ( + f"Tool `{name}` is not supported. " + 'Call `calc_score` (or `calc_geo3k_reward`) via {"name": "calc_score", "arguments": {"answer": ""}} (format must be (JSON))' + "to check your solution." + ), + "role": "tool", + } + info["tool_executed"] = False + return obs, is_final_turn, info + + raw_answer = arguments.get("answer", None) + parsed_answer = "" if raw_answer is None else str(raw_answer) + if not parsed_answer.strip(): + obs = { + "obs_str": ( + "Tool call detected but no `answer` was provided. " + 'Call `calc_score` (or `calc_geo3k_reward`) via {"name": "calc_score", "arguments": {"answer": ""}} ' + "to check your solution." 
+ ), + "role": "tool", + } + info["tool_executed"] = False + info["answer_missing"] = True + return obs, is_final_turn, info + + score = self._score_answer(parsed_answer) + self.last_tool_score = score + tool_record = {"name": name, "answer": parsed_answer, "score": score} + self.tool_calls.append(tool_record) + info.update(tool_record) + info["tool_executed"] = True + + obs = { + "obs_str": self._build_tool_feedback(score, parsed_answer), + "role": "tool", + "tool_score": score, + } + + return obs, is_final_turn, info + + def _parse_tool_payload(self, raw_json: str) -> dict[str, Any] | None: + """Parse tool payload strictly as JSON. Malformed payloads are rejected.""" + loader = orjson.loads if orjson is not None else json.loads + try: + return loader(raw_json) + except Exception as exc: + logger.warning("Failed to decode tool call payload: %s", exc) + return None + + +def _extract_ground_truth(sample: Sample | None) -> str | None: + """Resolve the ground-truth answer from label or metadata.""" + if sample is None: + return None + if sample.label is not None: + return str(sample.label) + # metadata = sample.metadata + # for key in ("answer", "ground_truth", "label"): + # if key in metadata and metadata[key] is not None: + # return str(metadata[key]) + return None + + +def build_env(sample: Sample | None = None, args: Any | None = None, **_: Any) -> Geo3kEnv: + """ + Construct a Geo3kEnv. Ground truth is pulled from sample.label or metadata. + """ + ground_truth = _extract_ground_truth(sample) + max_turns = args.max_turns + if max_turns is None: + raise ValueError("max_turns must be set via --custom-config-path in the custom config file.") + if ground_truth is None: + logger.warning("Ground truth answer missing; calc_score tool will always return 0.") + return Geo3kEnv(ground_truth=ground_truth, max_turns=max_turns) diff --git a/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml b/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml new file mode 100644 index 000000000..ad2dd6fef --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml @@ -0,0 +1,2 @@ +max_turns: 3 +rollout_interaction_env_path: examples.geo3k_vlm_multi_turn.env_geo3k diff --git a/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_reward.png b/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_reward.png new file mode 100644 index 000000000..88851bb46 Binary files /dev/null and b/examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_reward.png differ diff --git a/examples/geo3k_vlm_multi_turn/rollout.py b/examples/geo3k_vlm_multi_turn/rollout.py new file mode 100644 index 000000000..0e582c8db --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/rollout.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import importlib +import importlib.util +import sys +from pathlib import Path +from typing import Any + +import torch +from examples.geo3k_vlm_multi_turn.base_env import BaseInteractionEnv + +# When executed as a module: python -m examples.vlm_multi_turn.rollout +from miles.rollout.sglang_rollout import GenerateState +from miles.utils.http_utils import post +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + +DEFAULT_ENV_MODULE = "examples.vlm_multi_turn.env_geo3k" + +# Dummy messages used for calculating trim length in chat template encoding +DUMMY_MESSAGES = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "I am a user."}, +] + + +def _load_env_module(env_path: str 
| None): + """Load the interaction environment module from a module path or a file path.""" + target = env_path or DEFAULT_ENV_MODULE + module_path = Path(target) + if module_path.suffix == ".py" and module_path.exists(): + spec = importlib.util.spec_from_file_location(f"rollout_env_{module_path.stem}", module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot import environment module from {module_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + return importlib.import_module(target) + + +def _build_env(env_module, sample: Sample, args: Any): + """Instantiate the interaction environment using the provided module.""" + build_fn = env_module.build_env + if not callable(build_fn): + raise ValueError("Environment module must expose a callable `build_env(sample, args)`.") + try: + return build_fn(sample=sample, args=args) + except TypeError: + # Fallback to positional signature + return build_fn(sample, args) + + +def _encode_observation_for_generation( + tokenizer, + processor, + message: dict, + metadata: dict | None, + apply_chat_template: bool, + apply_chat_template_kwargs: dict | None, +): + """ + Encode a single observation turn that may include images/videos in the content list. + Trim out the system/tool preamble added by the chat template so only the observation tokens remain. + """ + tools = metadata.get("tools") if metadata else None + apply_kwargs = apply_chat_template_kwargs or {} + + trim_length = 0 + + if apply_chat_template: + dummy_prompt = tokenizer.apply_chat_template( + DUMMY_MESSAGES, + tools=tools, + tokenize=False, + add_generation_prompt=False, + **apply_kwargs, + ) + formatted_prompt = tokenizer.apply_chat_template( + DUMMY_MESSAGES + [message], + tools=tools, + tokenize=False, + add_generation_prompt=True, + **apply_kwargs, + ) + trim_length = len(tokenizer.encode(dummy_prompt, add_special_tokens=False)) + else: + formatted_prompt = [message] + + multimodal_inputs = None + multimodal_train_inputs = None + if processor: + # Convert content-embedded images/videos into multimodal inputs for the processor. + from qwen_vl_utils import process_vision_info + + images, videos = process_vision_info([message]) + multimodal_inputs = {"images": images, "videos": videos} + processor_output = processor(text=formatted_prompt, **multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False) + + if trim_length: + prompt_ids = prompt_ids[trim_length:] + + image_data = [] + if multimodal_inputs and multimodal_inputs.get("images"): + image_data = [encode_image_for_rollout_engine(img) for img in multimodal_inputs["images"]] + return prompt_ids, image_data, multimodal_inputs, multimodal_train_inputs + + +def _merge_multimodal_train_inputs(chunks: list[dict | None]) -> dict | None: + """ + Merge per-turn multimodal_train_inputs with a single concat per key. + + Note: Only torch.Tensor values are merged; non-tensor fields are ignored by design. 
+ """ + if not chunks: + return None + + values_by_key = {} + for chunk in chunks: + if not chunk: + continue + for key, val in chunk.items(): + if val is None: + continue + values_by_key.setdefault(key, []).append(val) + + merged = {} + for key, values in values_by_key.items(): + if all(isinstance(v, torch.Tensor) for v in values): + merged[key] = torch.cat(values, dim=0) + + return merged + + +def _initialize_resources(args: Any, sample: Sample): + env_module = _load_env_module(args.rollout_interaction_env_path) + max_turns = args.max_turns + if max_turns is None: + raise ValueError("max_turns must be set via --custom-config-path in the custom config file.") + state = GenerateState(args) + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + sample.metadata = sample.metadata or {} + env = _build_env(env_module, sample, args) + config = {"max_turns": max_turns} + return env, env_module, config, state, url + + +def _prepare_initial_inputs(sample: Sample, processor, tokenizer): + if processor: + processor_output = processor(text=sample.prompt, **(sample.multimodal_inputs or {})) + prompt_ids = processor_output["input_ids"][0] + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = tokenizer.encode(sample.prompt, add_special_tokens=False) + + image_data = [] + if sample.multimodal_inputs and sample.multimodal_inputs.get("images"): + image_data = [encode_image_for_rollout_engine(img) for img in sample.multimodal_inputs["images"]] + return prompt_ids, image_data, sample.multimodal_train_inputs + + +def _prepare_start_state(sample: Sample, state, args: Any, sampling_params: dict): + prompt_ids, image_data, init_mm_train = _prepare_initial_inputs(sample, state.processor, state.tokenizer) + current_image_data = image_data + multimodal_train_inputs_buffer: list[dict | None] = [] + if init_mm_train: + multimodal_train_inputs_buffer.append(init_mm_train) + + if not sample.tokens: + sample.tokens = list(prompt_ids) + response_tokens: list[int] = sample.tokens[len(prompt_ids) :] if len(sample.tokens) >= len(prompt_ids) else [] + sample.loss_mask = sample.loss_mask or [] + sample.rollout_log_probs = sample.rollout_log_probs or [] + sample.response_length = len(response_tokens) + + budget = None + if args.rollout_max_context_len is not None: + budget = args.rollout_max_context_len - len(sample.tokens) + elif sampling_params.get("max_new_tokens") is not None: + budget = sampling_params["max_new_tokens"] - len(sample.tokens) + return current_image_data, response_tokens, budget, multimodal_train_inputs_buffer + + +async def _run_inference_step(url: str, tokens: list[int], sampling_params: dict, image_data, tokenizer): + payload = { + "input_ids": tokens, + "sampling_params": sampling_params, + "return_logprob": True, + } + if image_data: + payload["image_data"] = image_data + + output = await post(url, payload) + response_text = output["text"] + if "output_token_logprobs" in output["meta_info"]: + new_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + new_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + else: + new_tokens, new_log_probs = [], [] + finish_type = output["meta_info"]["finish_reason"]["type"] + return response_text, new_tokens, new_log_probs, finish_type + + +def _process_env_step(env: BaseInteractionEnv, response_text: str, tokenizer, processor, args, sample_metadata): + observation, done, _ = 
env.step(response_text) + if done: + return None, None, None, None, True + + next_user_message = env.format_observation(observation) + obs_prompt_ids, obs_image_data, obs_multimodal_inputs, obs_multimodal_train_inputs = ( + _encode_observation_for_generation( + tokenizer, + processor, + next_user_message, + sample_metadata, + args.apply_chat_template, + args.apply_chat_template_kwargs, + ) + ) + + bos_id = tokenizer.bos_token_id + if bos_id is not None and obs_prompt_ids and obs_prompt_ids[0] == bos_id: + obs_prompt_ids = obs_prompt_ids[1:] + + return obs_prompt_ids, obs_image_data, obs_multimodal_inputs, obs_multimodal_train_inputs, False + + +def _append_to_sample( + sample: Sample, + response_tokens: list[int], + tokens_to_add: list[int], + logprobs: list[float], + loss_mask_val: int, +) -> None: + sample.tokens.extend(tokens_to_add) + response_tokens.extend(tokens_to_add) + sample.loss_mask.extend([loss_mask_val] * len(tokens_to_add)) + sample.rollout_log_probs.extend(logprobs) + sample.response_length = len(response_tokens) + + +def _update_multimodal_state( + sample: Sample, + current_image_data, + obs_image_data, + obs_multimodal_inputs, + obs_multimodal_train_inputs, + multimodal_train_inputs_buffer: list[dict | None], +): + if obs_image_data: + current_image_data = (current_image_data or []) + obs_image_data + + if obs_multimodal_inputs: + if not sample.multimodal_inputs: + sample.multimodal_inputs = obs_multimodal_inputs + elif isinstance(sample.multimodal_inputs, dict) and isinstance(obs_multimodal_inputs, dict): + for key, val in obs_multimodal_inputs.items(): + if val is None: + continue + if ( + key in sample.multimodal_inputs + and isinstance(sample.multimodal_inputs[key], list) + and isinstance(val, list) + ): + sample.multimodal_inputs[key].extend(val) + else: + sample.multimodal_inputs = obs_multimodal_inputs + + if obs_multimodal_train_inputs: + multimodal_train_inputs_buffer.append(obs_multimodal_train_inputs) + + return current_image_data + + +def _should_stop_on_finish(sample: Sample, finish_type: str) -> bool: + match finish_type: + case "length": + sample.status = Sample.Status.TRUNCATED + return True + case "abort": + sample.status = Sample.Status.ABORTED + return True + return False + + +def _update_budget(budget, consumed: int): + if budget is None: + return None + return budget - consumed + + +def _finalize_sample(sample: Sample, tokenizer, response_tokens, multimodal_train_inputs_buffer): + sample.multimodal_train_inputs = _merge_multimodal_train_inputs(multimodal_train_inputs_buffer) + sample.response = tokenizer.decode(response_tokens, skip_special_tokens=False) + sample.response_length = len(response_tokens) + if sample.status is None: + sample.status = Sample.Status.COMPLETED + return sample + + +async def generate(args: Any, sample: Sample, sampling_params) -> Sample: + """Custom multi-turn rollout that interacts with a pluggable environment.""" + assert not args.partial_rollout, "Partial rollout is not supported for interaction rollouts." 
+ + env, env_module, config, state, url = _initialize_resources(args, sample) + sampling_params = sampling_params.copy() + current_image_data, response_tokens, budget, multimodal_train_inputs_buffer = _prepare_start_state( + sample, state, args, sampling_params + ) + try: + env.reset() + if budget is not None and budget <= 0: + sample.status = Sample.Status.TRUNCATED + return sample + + cur_sampling_params = sampling_params + for turn_idx in range(config["max_turns"]): + if budget is not None: + cur_sampling_params["max_new_tokens"] = budget + + response_text, new_response_tokens, new_response_log_probs, finish_type = await _run_inference_step( + url, sample.tokens, cur_sampling_params, current_image_data, state.tokenizer + ) + _append_to_sample(sample, response_tokens, new_response_tokens, new_response_log_probs, loss_mask_val=1) + budget = _update_budget(budget, len(new_response_tokens)) + + if _should_stop_on_finish(sample, finish_type): + break + if budget is not None and budget <= 0: + sample.status = Sample.Status.TRUNCATED + break + + obs_prompt_ids, obs_image_data, obs_multimodal_inputs, obs_multimodal_train_inputs, done = ( + _process_env_step(env, response_text, state.tokenizer, state.processor, args, sample.metadata) + ) + if done: + sample.status = Sample.Status.COMPLETED + break + + obs_log_probs = [0.0] * len(obs_prompt_ids) + _append_to_sample(sample, response_tokens, obs_prompt_ids, obs_log_probs, loss_mask_val=0) + budget = _update_budget(budget, len(obs_prompt_ids)) + + current_image_data = _update_multimodal_state( + sample, + current_image_data, + obs_image_data, + obs_multimodal_inputs, + obs_multimodal_train_inputs, + multimodal_train_inputs_buffer, + ) + + if budget is not None and budget <= 0: + sample.status = Sample.Status.TRUNCATED + break + if turn_idx + 1 >= config["max_turns"]: + sample.status = Sample.Status.COMPLETED + break + + return _finalize_sample(sample, state.tokenizer, response_tokens, multimodal_train_inputs_buffer) + finally: + try: + env.close() + except Exception: + pass diff --git a/examples/geo3k_vlm_multi_turn/rollout_experiment_result_megatron.png b/examples/geo3k_vlm_multi_turn/rollout_experiment_result_megatron.png new file mode 100644 index 000000000..dd249de29 Binary files /dev/null and b/examples/geo3k_vlm_multi_turn/rollout_experiment_result_megatron.png differ diff --git a/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py new file mode 100644 index 000000000..5c32d33d3 --- /dev/null +++ b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py @@ -0,0 +1,171 @@ +import os + +import miles.utils.misc as U +from miles.utils.external_utils.command_utils import execute_train + +MODEL_NAME = os.environ.get("MILES_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct") +assert MODEL_NAME in { + "Qwen3-VL-2B-Instruct", + "Qwen3-VL-4B-Instruct", + "Qwen3-VL-8B-Instruct", + "Qwen3-VL-2B-Thinking", + "Qwen3-VL-4B-Thinking", + "Qwen3-VL-8B-Thinking", +} + +NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "4")) +EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) +TRAIN_BACKEND = os.environ.get("MILES_SCRIPT_TRAIN_BACKEND", "fsdp").lower() +assert TRAIN_BACKEND in {"fsdp", "megatron"} + +DATASET_NAME = "VeraIsHere/geo3k_imgurl_processed" +DATA_ROOT = "/root/datasets/geo3k_imgurl_processed" +TRAIN_DATA_PATH = os.path.join(DATA_ROOT, "train.parquet") + + +def get_megatron_model_type(model_name: str) -> str: + model_type = model_name.replace("-Instruct", 
"").replace("-Thinking", "") + model_type = model_type.replace("Qwen3-VL-", "qwen3-") + return model_type.replace("-2B", "-1.7B") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + data_missing = not os.path.exists(TRAIN_DATA_PATH) + if data_missing: + U.exec_command(f"hf download --repo-type dataset {DATASET_NAME} --local-dir {DATA_ROOT}") + if not os.path.exists(TRAIN_DATA_PATH): + raise FileNotFoundError(f"Dataset not found. Expected local dataset at {TRAIN_DATA_PATH}; ") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + wandb_args = ( + ( + "--use-wandb " + "--wandb-project miles-dev " + "--wandb-group geo3k_vlm_multi_turn " + f"--wandb-key '{wandb_api_key}' " + ) + if (wandb_api_key := os.environ.get("WANDB_API_KEY")) + else "" + ) + + rollout_args = ( + f"--prompt-data {TRAIN_DATA_PATH} " + "--input-key problem " + "--label-key answer " + '--multimodal-keys \'{"image": "images"}\' ' + "--rm-type math " + "--apply-chat-template " + "--custom-generate-function-path examples.geo3k_vlm_multi_turn.rollout.generate " + "--custom-config-path examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml " + "--rollout-shuffle " + "--num-rollout 3000 " + "--rollout-batch-size 64 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 512 " + ) + + # eval_args = ( + # "--eval-interval 20 " + # f"--eval-prompt-data geo3k_eval {TRAIN_DATA_PATH}@[0:64] " + # "--n-samples-per-eval-prompt 1 " + # "--eval-max-response-len 4096 " + # "--eval-top-k 1 " + # ) + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-mem-fraction-static 0.6 " + f"--sglang-cuda-graph-bs {' '.join(map(str, [1, 2, 4, 8] + list(range(16, 257, 8))))} " + ) + + fsdp_args = ( + "--train-backend fsdp " + "--gradient-checkpointing " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + "--update-weight-buffer-size 536870912 " + ) + + megatron_args = ( + "--train-backend megatron " + f"--load /root/models/{MODEL_NAME} " + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--megatron-to-hf-mode bridge " + ) + + misc_args = ( + "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " f"--rollout-num-gpus {NUM_GPUS} " "--colocate " + ) + + if TRAIN_BACKEND == "megatron": + backend_args = megatron_args + megatron_model_type = get_megatron_model_type(MODEL_NAME) + os.environ["MODEL_ARGS_ROTARY_BASE"] = "5000000" + else: + backend_args = fsdp_args + megatron_model_type = None + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + 
f"{grpo_args} " + f"{sglang_args} " + f"{backend_args} " + f"{misc_args} " + f"{wandb_args} " + # f"{get_default_wandb_args(__file__)} " + ) + + execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=megatron_model_type, + extra_env_vars=({"WANDB_API_KEY": os.environ["WANDB_API_KEY"]} if os.environ.get("WANDB_API_KEY") else {}), + ) + + +if __name__ == "__main__": + prepare() + execute() diff --git a/examples/low_precision/README.md b/examples/low_precision/README.md index 12ceed735..97389fd91 100644 --- a/examples/low_precision/README.md +++ b/examples/low_precision/README.md @@ -1,14 +1,14 @@ -## FP8 training examples +# FP8 training examples -This is an example of FP8 training and FP8 inference. Under FP8 training and inference, it can achieve more efficient inference throughput and lower training-inference mismatch, resulting in more stable training. +This is an example of FP8 training and FP8 inference. Under FP8 training and inference, it can achieve more efficient inference throughput and lower training-inference mismatch, resulting in more stable training. More details can be found in [this blog](https://lmsys.org/blog/2025-11-25-fp8-rl/). -### Files +## Files * `run-qwen3-4b-fp8.sh`: example launch script with Qwen3โ€‘4B in FP8. * `run-qwen3-30b-a3b-fp8-two-nodes.sh`: example launch script for running Qwen3โ€‘30Bโ€‘A3B in FP8 across two nodes. -### Quick Start +## Quick Start 1. Check if your training script is properly configured. @@ -44,7 +44,7 @@ Following the above command will launch FP8 training. Note that TransformerEngine does not specifically save FP8 quantized weights; the saved torch dist remains in original precision (usually bf16). If you want to evaluate under FP8, you need to convert the checkpoint from `torch_dist` to HuggingFace format, then convert to FP8 HuggingFace format. -### Quick Explanation +## Quick Explanation Here's a quick explanation of how FP8 training is currently implemented in miles: @@ -57,10 +57,79 @@ Here's a quick explanation of how FP8 training is currently implemented in miles 4. Save checkpoint: Similar to weight updates, if checkpoints need to be saved from the training engine, they will also be dequantized back to bf16 and saved to `torch_dist` format checkpoints. -### TODO +## TODO Currently, FP8 is far from being a complete feature and still has the following bugs, for examples: - FP8 weights (`--fp8-param-gather`) can provide memory savings benefits, but currently FP8 weights must be used with TransformerEngine's FusedAdam, which conflicts with the commonly used Adam CPU offload technique in Megatron-LM. -The miles team will continue to collaborate with the NVIDIA team to contribute more complete FP8 training infrastructure to the community. \ No newline at end of file +The miles team will continue to collaborate with the NVIDIA team to contribute more complete FP8 training infrastructure to the community. + +*** + +## INT4 Training Examples + +This guide provides examples for INT4 STE (Straight-Through Estimator) training and INT4 inference. Utilizing INT4 inference significantly improves throughput, thereby accelerating the training pipeline (specifically during the rollout generation phase). + +### Files + +* `run-moonlight-16B-A3B-int4.sh`: Launch script for **Moonlight-16B-A3B** (INT4) on 4x H200 GPUs. +* `run-qwen3โ€‘30Bโ€‘A3B-int4.sh`: Launch script for **Qwen3โ€‘30Bโ€‘A3B** (INT4) on 8x H200 GPUs. 
+* `run-qwen3-235B-A22B-int4.sh`: Launch script for **Qwen3-235B-A22B** (INT4) on 64x H200 GPUs.
+* `run-kimi-k2-Thinking-int4.sh`: Launch script for **Kimi-k2-Thinking** (INT4) on 256x H200 GPUs.
+
+### Quick Start
+
+#### 1. Convert HuggingFace Weights to INT4
+First, download the PTQ (Post-Training Quantization) calibration dataset from HuggingFace:
+[https://huggingface.co/datasets/Salesforce/wikitext/tree/main/wikitext-2-raw-v1](https://huggingface.co/datasets/Salesforce/wikitext/tree/main/wikitext-2-raw-v1)
+
+Next, use the `tools/convert_hf_to_hf_int4.py` script to convert BF16 weights to INT4 format. Ensure that the `--hf-checkpoint` parameter points to a directory whose `config.json` contains the correct `quantization_config`; miles will then automatically apply INT4 quantization during weight updates.
+
+```bash
+python tools/convert_hf_to_hf_int4.py \
+    --input-dir /path/to/your/original/models \
+    --output-dir /path/to/your/save/models \
+    --data-dir /path/to/your/wikitext
+```
+
+#### 2. Start INT4 Training
+
+Configure the following environment variables to control quantization (a conceptual sketch of group-wise fake quantization is included at the end of this section).
+
+**Environment Variables:**
+
+* **`OPEN_TRAINING_INT4_FAKE_QAT_FLAG`**: Enables fake quantization operations for INT4 training.
+* **`OPEN_TRAINING_INT4_GROUP_SIZE`**: Specifies the block size (group size) for model quantization.
+    * Set to **128** for `moonlight-16B-A3B`, `qwen3-30B-A3B`, and `qwen3-235B-A22B-int4`.
+    * Set to **32** for `kimi-k2-Thinking-int4`.
+
+**Configuration Example:**
+
+```bash
+RUNTIME_ENV_JSON="{
+  \"env_vars\": {
+    ...
+    \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\",
+    \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\"
+  }
+}"
+```
+
+**Launch Commands:**
+
+```bash
+# Moonlight-16B-A3B INT4 training
+bash examples/low_precision/run-moonlight-16B-A3B-int4.sh
+
+# Qwen3-30B-A3B INT4 training
+bash examples/low_precision/run-qwen3-30B-A3B-int4.sh
+
+# Qwen3-235B-A22B INT4 training (8 nodes)
+bash examples/low_precision/run-qwen3-235B-A22B-int4.sh
+
+# Kimi-k2-Thinking INT4 training (32 nodes)
+bash examples/low_precision/run-kimi-k2-Thinking-int4.sh
+```
+
+- For multi-node environments, please start the Ray service according to your cluster configuration.
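+
+To make the two environment variables above concrete, the sketch below shows what symmetric group-wise INT4 *fake* quantization with a straight-through estimator (STE) looks like. It is an illustrative stand-in, not the actual code path miles uses, and it assumes the number of weights is divisible by the group size:
+
+```python
+import torch
+
+
+def fake_quant_int4_ste(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
+    """Quantize-dequantize weights in groups of `group_size`, keeping gradients via STE."""
+    groups = w.reshape(-1, group_size)                    # one scale per group
+    scale = groups.abs().amax(dim=1, keepdim=True) / 7.0  # symmetric int4 range is [-8, 7]
+    scale = scale.clamp(min=1e-8)
+    q = (groups / scale).round().clamp(-8, 7)             # "fake" int4 values
+    dequant = (q * scale).reshape(w.shape)
+    # Straight-through estimator: forward sees the quantized weights,
+    # backward passes gradients through to `w` unchanged.
+    return w + (dequant - w).detach()
+```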
\ No newline at end of file diff --git a/examples/low_precision/run-kimi-k2-Thinking-int4.sh b/examples/low_precision/run-kimi-k2-Thinking-int4.sh new file mode 100644 index 000000000..3bedbf88a --- /dev/null +++ b/examples/low_precision/run-kimi-k2-Thinking-int4.sh @@ -0,0 +1,189 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../models/kimi-k2-thinking.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Kimi-K2-Thinking/ + --ref-load /root/Kimi-K2_thinking_torch_dist/ + --load /root/Kimi-K2-thinking_miles/ + --save /root/Kimi-K2-thinking_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + + --rm-type math + + --num-rollout 100 + --rollout-batch-size 128 + --n-samples-per-prompt 8 + --rollout-max-response-len 16384 + --rollout-temperature 0.8 + + # --global-batch-size 256 + + --over-sampling-batch-size 256 + --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + + --num-steps-per-rollout 4 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 10 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 8 + --sequence-parallel + --pipeline-model-parallel-size 8 + --context-parallel-size 4 + --expert-model-parallel-size 32 + --expert-tensor-parallel-size 1 + --decoder-last-pipeline-num-layers 5 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 16384 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + # --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project miles-dev + # --wandb-group kimi-k2-thinking-test + # --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.7 + + # dp attention + # --sglang-enable-dp-attention + # --sglang-dp-size 8 + # --sglang-moe-dense-tp-size 1 + # --sglang-enable-dp-lm-head + # --sglang-disable-radix-cache + + --sglang-ep-size 8 + + # enable deepep for sglang + #--sglang-enable-deepep-moe + #--sglang-deepep-mode auto + + # make every dp rank has 128 concurrency + --sglang-server-concurrency 1024 + --use-miles-router +) + + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + + # use deepep 
for megatron + # --moe-enable-deepep + # --moe-token-dispatcher-type flex + --no-check-for-nan-in-loss-and-grad +) + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"NCCL_TIMEOUT_MS\":\"360000000\", + \"no_proxy\": \"${no_proxy}\", + \"MASTER_ADDR\": \"${MASTER_ADDR}\", + \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\", + \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"32\" + } +}" + + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 /personal/miles/miles/train.py \ + --actor-num-nodes 32 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + --update-weight-buffer-size $(( 4 * 512 * 1024 * 1024)) \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/low_precision/run-moonlight-16B-A3B-int4.sh b/examples/low_precision/run-moonlight-16B-A3B-int4.sh new file mode 100644 index 000000000..12ea3ee81 --- /dev/null +++ b/examples/low_precision/run-moonlight-16B-A3B-int4.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python +pkill -9 redis + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../models/moonlight.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Moonlight-16B-A3B-Instruct-INT4 + --ref-load /root/Moonlight-16B-A3B-Instruct-INT4_torch_dist + --load /root/Moonlight-16B-A3B_miles/ + --save /root/Moonlight-16B-A3B_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 3000 + --rollout-batch-size 128 + --n-samples-per-prompt 8 + --rollout-max-response-len 4096 + --rollout-temperature 0.8 + + --over-sampling-batch-size 256 + --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + + --num-steps-per-rollout 4 + # --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 8 + --eval-max-response-len 4096 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 4 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 8192 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + 
--weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project miles-dev + # --wandb-group moomlight-16B-A3B-test + # --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 4 + --sglang-mem-fraction-static 0.7 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + # --attention-backend flash + + # use deepep for megatron + --moe-enable-deepep + --moe-token-dispatcher-type flex +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\", + \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/examples/low_precision/run-qwen3-235B-A22B-int4.sh b/examples/low_precision/run-qwen3-235B-A22B-int4.sh new file mode 100644 index 000000000..e0e3e6c6b --- /dev/null +++ b/examples/low_precision/run-qwen3-235B-A22B-int4.sh @@ -0,0 +1,171 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../models/qwen3-235B-A22B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-235B-A22B-INT4/ + --ref-load /root/Qwen3-235B-A22B_torch_dist/ + --load /root/Qwen3-235B-A22B-miles/ + --save /root/Qwen3-235B-A22B-miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + + --rm-type deepscaler + + --num-rollout 300 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 10 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 4 + --context-parallel-size 2 + --expert-model-parallel-size 16 + 
--expert-tensor-parallel-size 1 + --decoder-last-pipeline-num-layers 22 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 16384 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + # --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project miles-dev + # --wandb-group qwen3-235B-A22B-test + # --wandb-key ${WANDB_KEY} +) + + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.7 + # --sglang-enable-dp-attention + # --sglang-dp-size 4 + --sglang-ep-size 8 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) + --use-miles-router +) + + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + --no-check-for-nan-in-loss-and-grad + + # use deepep for megatron + # --moe-enable-deepep + # --moe-token-dispatcher-type flex +) + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"NCCL_TIMEOUT_MS\":\"360000000\", + \"no_proxy\": \"${no_proxy}\", + \"MASTER_ADDR\": \"${MASTER_ADDR}\", + \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\", + \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 8 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/low_precision/run-qwen3-30B-A3B-int4.sh b/examples/low_precision/run-qwen3-30B-A3B-int4.sh new file mode 100644 index 000000000..eb0c870b3 --- /dev/null +++ b/examples/low_precision/run-qwen3-30B-A3B-int4.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderrs +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../models/qwen3-30B-A3B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-30B-A3B-INT4/ + --ref-load /root/Qwen3-30B-A3B_torch_dist/ + --load /root/Qwen3-30B-A3B_miles/ + --save /root/Qwen3-30B-A3B_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + 
--apply-chat-template + --rollout-shuffle + + --rm-type deepscaler + + --num-rollout 100 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 256 + --balance-data + # --debug-rollout-only +) + +EVAL_ARGS=( + --eval-interval 10 + --eval-prompt-data /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 8 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 8192 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project miles-dev + # --wandb-group qwen3-30B-A3B-test + # --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) + --use-miles-router +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + # use deepep for megatron + # --moe-enable-deepep + # --moe-token-dispatcher-type flex +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\", + \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} + \ No newline at end of file diff --git a/examples/low_precision/run-qwen3-4b-fp8.sh b/examples/low_precision/run-qwen3-4b-fp8.sh index b196ba606..89b7079ad 100644 --- a/examples/low_precision/run-qwen3-4b-fp8.sh +++ b/examples/low_precision/run-qwen3-4b-fp8.sh @@ -120,14 +120,6 @@ MISC_ARGS=( --attention-backend flash ) -PRECISE_ARGS=( - --transformer-impl transformer_engine - --bf16 - --fp8-format e4m3 - --fp8-recipe blockwise - --fp8-param-gather -) - # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} @@ -159,5 +151,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} 
\ ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${PRECISE_ARGS[@]} \ No newline at end of file + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/multi_agent/agent_system.py b/examples/multi_agent/agent_system.py index a0d937358..46db66143 100644 --- a/examples/multi_agent/agent_system.py +++ b/examples/multi_agent/agent_system.py @@ -20,11 +20,10 @@ async def generate_response(args, prompt, key): url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False) - sample.tokens = prompt_token_ids sample.prompt = prompt - input_token_ids = prompt_token_ids - prompt_length = len(input_token_ids) + prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"] + sample.tokens = prompt_token_ids + prompt_length = len(prompt_token_ids) current_sampling_params = deepcopy(sampling_params) current_sampling_params["max_new_tokens"] = min( sampling_params["max_new_tokens"], max_context_length - prompt_length @@ -33,7 +32,7 @@ async def generate_response(args, prompt, key): if current_sampling_params["max_new_tokens"] <= 0: return None - payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True} + payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True} output = await post(url, payload) @@ -150,7 +149,7 @@ async def select(self, args, problem_statement, candidate_solutions: list[str]) def extract_selected_solution_idx(self, response: str, candidate_solutions: list[str]) -> int: """Extracts the selected solution ID from the response.""" - PATTERN = re.compile("Judgment:\s*(\d+)") + PATTERN = re.compile(r"Judgment:\s*(\d+)") matched = PATTERN.findall(response) try: selected_id = int(matched[0]) - 1 diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md new file mode 100644 index 000000000..a6b22c5b1 --- /dev/null +++ b/examples/on_policy_distillation/README.md @@ -0,0 +1,59 @@ +# On-Policy Distillation Example + +This example shows how to run **on-policy distillation** using miles. A small student (Qwen3-8B) is aligned to imitate a larger teacher (Qwen3-32B) by training only on the student's own rollouts and matching the teacher's token-level log-probabilities. + +In this example, the teacher model acts as a reward model (RM) by providing teacher log probabilities as the supervision signal. + +## Components + +- `on_policy_distillation.py` implements:: + - `reward_func` calls the teacher server (via `args.rm_url`) with every sample to obtain token-level logprobs. + - `post_process_rewards` trims the teacher logprobs to the generated response span and writes the tensors back to each `Sample` to compute advantages. +- `run-qwen3-8B-opd.sh` launches an SGLang teacher server, then submits a Ray job that runs `train.py`. + +## Running the example + +1. Download or prepare the required checkpoints and data. +```bash +hf download Qwen/Qwen3-32B --local-dir /root/Qwen3-32B +hf download Qwen/Qwen3-8B --local-dir /root/Qwen3-8B +hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k +``` + +2. Run the hf to mcore for student model conversion: +```bash +cd /root/miles +source scripts/models/qwen3-8B.sh + +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-8B \ + --save /root/Qwen3-8B_torch_dist +``` +3. 
run on-policy distillation: +```bash +bash examples/on_policy_distillation/run-qwen3-8B-opd.sh +``` + + +# Preliminary Results +Using Qwen3-8B-Base model sfted on part of the [OpenThoughts3-1.2M](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M) dataset, we performed on-policy distillation with a Qwen3-32B teacher on the remaining data. Evaluation on Math500 shows: + +| | Pass@1 | +|-----------------------------------------------|--------| +| Qwen3-8B-Base + SFT | 76% | +| Qwen3-8B-Base + SFT + On-Policy Distillation | 94% | + + + + + +# FAQ +1. **Why are teacher logits computed via a sglang server instead of inside the training backend?** +The teacher runs on an independent SGLang server that miles treats as a reward model. Hosting it inside Megatron/FSDP would require maintaining a second, fully configured training stack for the teacher. + + +# References +1. https://thinkingmachines.ai/blog/on-policy-distillation/ +2. https://arxiv.org/abs/2306.13649 +3. https://arxiv.org/abs/2306.08543 \ No newline at end of file diff --git a/examples/on_policy_distillation/on_policy_distillation.py b/examples/on_policy_distillation/on_policy_distillation.py index 929e2d64c..94a7c29a7 100644 --- a/examples/on_policy_distillation/on_policy_distillation.py +++ b/examples/on_policy_distillation/on_policy_distillation.py @@ -6,7 +6,8 @@ async def reward_func(args, sample, **kwargs): payload = { - "text": sample.prompt + sample.response, + # "text": sample.prompt + sample.response, + "input_ids": sample.tokens, "sampling_params": { "temperature": 0, "max_new_tokens": 0, diff --git a/examples/on_policy_distillation/run-qwen3-8B-opd.sh b/examples/on_policy_distillation/run-qwen3-8B-opd.sh index c57b9eef4..f45c2634b 100644 --- a/examples/on_policy_distillation/run-qwen3-8B-opd.sh +++ b/examples/on_policy_distillation/run-qwen3-8B-opd.sh @@ -29,6 +29,7 @@ until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health_generate > /dev/null; do sleep 5 done +curl http://$TEACHER_IP:$TEACHER_PORT/get_model_info echo "Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT." sleep 10 diff --git a/examples/reproducibility/README.md b/examples/reproducibility/README.md index 84fbb028a..7005f3cd9 100644 --- a/examples/reproducibility/README.md +++ b/examples/reproducibility/README.md @@ -29,8 +29,8 @@ For data and checkpoint preparation, please run: ```bash # download -huggingface-cli download --repo-type dataset zhuzilin/gsm8k --local-dir /root/gsm8k -huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/Qwen2.5-0.5B-Instruct +hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/gsm8k +hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/Qwen2.5-0.5B-Instruct # convert ckpt cd miles/ @@ -48,4 +48,4 @@ And to run training, bash examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh ``` -For screen shots of the wandb, please refer to [pull#370](https://github.com/radixark/miles/pull/370). +For screen shots of the wandb, please refer to [pull#370](https://github.com/THUDM/slime/pull/370). diff --git a/examples/retool/README.md b/examples/retool/README.md index b4e3f71eb..bd9af717b 100644 --- a/examples/retool/README.md +++ b/examples/retool/README.md @@ -21,7 +21,7 @@ The retool example provides: 1. Setup and download datasets: ```bash cd miles -pip install -e . +pip install -e . --no-deps # For SFT part, you can use later model to RL directly and skip SFT. 
hf download --repo-type dataset JoeYing/ReTool-SFT --local-dir /root/JoeYing/ReTool-SFT hf download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen/Qwen3-4B-Instruct-2507 diff --git a/examples/retool/generate_with_retool.py b/examples/retool/generate_with_retool.py index 068ca07f9..f5b8ad268 100644 --- a/examples/retool/generate_with_retool.py +++ b/examples/retool/generate_with_retool.py @@ -230,9 +230,20 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: tool_call_count = 0 # Track actual tool call rounds for turn in range(TOOL_CONFIGS["max_turns"]): - # Simple: just send prompt + response + # Check if total length exceeds max context length + total_length = len(prompt_tokens_ids) + len(response_token_ids) + if args.rollout_max_context_len is not None: + max_context_length = args.rollout_max_context_len + else: + max_context_length = args.context_parallel_size * args.max_tokens_per_gpu + if total_length >= max_context_length: + sample.status = Sample.Status.TRUNCATED + break + + # Use token IDs instead of text + current_token_ids = prompt_tokens_ids + response_token_ids payload = { - "text": prompt + response, + "input_ids": current_token_ids, "sampling_params": sampling_params, "return_logprob": True, # Request log probabilities for training } @@ -265,15 +276,16 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: sample.status = Sample.Status.ABORTED return sample - cur_response = output["text"] - if "output_token_logprobs" in output["meta_info"]: cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + cur_response = state.tokenizer.decode(cur_response_token_ids) cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] if sample.rollout_log_probs is None: sample.rollout_log_probs = [] sample.rollout_log_probs += cur_log_probs + else: + cur_response = output["text"] cur_response = postprocess_responses(cur_response) cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"] @@ -309,7 +321,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: sample.rollout_log_probs ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" - if turn >= TOOL_CONFIGS["max_tool_calls"]: + if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]: break # Set sample attributes diff --git a/examples/search-r1/README.md b/examples/search-r1/README.md index 46b1ef85c..867ca504b 100644 --- a/examples/search-r1/README.md +++ b/examples/search-r1/README.md @@ -9,7 +9,7 @@ Use the `radixark/miles:latest` image and initialize the environment required fo ```bash cd /root/ git clone https://github.com/radixark/miles.git -pip install -e . +pip install -e . --no-deps # for Search R1 pip install chardet ``` @@ -20,6 +20,8 @@ Download and prepare the training data: cd /root/ git clone https://github.com/PeterGriffinJin/Search-R1.git cd Search-R1/ +pip install -e . 
--no-deps +pip install tensordict # Set your working directory WORK_DIR=/root/Search-R1 @@ -45,7 +47,7 @@ Initialize the Qwen2.5-3B model: ```bash # hf checkpoint -huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B +hf download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B # mcore checkpoint cd /root/miles diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py index 2549e2a68..a1096b745 100644 --- a/examples/search-r1/generate_with_search.py +++ b/examples/search-r1/generate_with_search.py @@ -150,8 +150,8 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" # Handle partial rollout samples: continue generation from existing response - prompt = sample.prompt - prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"] + prompt_text = sample.prompt + prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"] response = "" response_token_ids = [] loss_mask = [] @@ -159,7 +159,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]): payload = { - "text": prompt + response, + "text": prompt_text + response, "sampling_params": sampling_params, } # Add log probability collection if enabled @@ -230,6 +230,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: sample.response_length = len(response_token_ids) sample.response = response sample.loss_mask = loss_mask + sample.prompt = prompt_text # Store log probs if enabled if SEARCH_R1_CONFIGS["return_logprob"]: diff --git a/examples/strands-agents/README.md b/examples/strands-agents/README.md deleted file mode 100644 index 3a785f31d..000000000 --- a/examples/strands-agents/README.md +++ /dev/null @@ -1,56 +0,0 @@ -# Miles x Strands-Agents - -This is a running example that connects the [Strands-Agents](https://github.com/strands-agents/sdk-python) agent scaffolding framework with Miles for RL training. - -## Install Dependencies - -1. Pull the `radixark/miles:latest` image and enter it -2. Goes to miles folder: `cd /root/miles` (Clone the repository if not already there: `cd /root && git clone https://github.com/radixark/miles.git`) -3. Install Miles: `pip install -e .` -4. Goes to the example folder: `cd /root/miles/examples/strands-agents` -5. 
Install other dependencies: `pip install -r requirements.txt` - -> NOTE: we use camel-ai's subprocess code interpreter for python code execution, which is NOT a good practice; it's just for convenience of this example and the dependencies for solving math problems are usually ready in `miles`'s docker - -## Prepare Model - -```bash -# hf checkpoint -huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/models/Qwen/Qwen3-4B-Instruct-2507 - -# mcore checkpoint -cd /root/miles -source scripts/models/qwen3-4B.sh -PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ - ${MODEL_ARGS[@]} \ - --hf-checkpoint /root/models/Qwen/Qwen3-4B-Instruct-2507 \ - --save /root/models/Qwen/Qwen3-4B-Instruct-2507_torch_dist -``` - -## Prepare Dataset - -Following [Retool](https://arxiv.org/abs/2504.11536), we used `dapo-math-17k` as training data: - -``` -from datasets import load_dataset -ds = load_dataset("zhuzilin/dapo-math-17k", split="train") -ds.to_json("/root/data/dapo-math-17k.jsonl", orient="records", lines=True) -``` - -and `aime-2024` as eval data: - -``` -from datasets import load_dataset -ds = load_dataset("zhuzilin/aime-2024", split="train") -ds.to_json("/root/data/aime-2024.jsonl", orient="records", lines=True) -``` - -## Run Training - -Assuming `/root/miles` is up-to-date (if this PR is not merged you may need to switch branch): - -``` -cd /root/miles -export WANDB_KEY=$your_wandb_key -bash examples/strands-agents/strands_qwen3_4b.sh -``` diff --git a/examples/strands-agents/generate_with_strands.py b/examples/strands-agents/generate_with_strands.py deleted file mode 100644 index 4f020614a..000000000 --- a/examples/strands-agents/generate_with_strands.py +++ /dev/null @@ -1,267 +0,0 @@ -import logging - -import openai -import wandb -from camel.interpreters import SubprocessInterpreter -from strands import Agent, tool -from strands.models.openai import OpenAIModel -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException - -from miles.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score -from miles.rollout.sglang_rollout import GenerateState -from miles.utils.types import Sample - -logging.basicConfig(level=logging.INFO) - -logger = logging.getLogger(__name__) - - -SYSTEM_PROMPT = """ -You are a helpful math-solving assistant with access to the `execute_python_code` tool. - -Guidelines: -- For any numerical or symbolic computation, always use the `execute_python_code` tool rather than performing calculations mentally. -- Break problems into clear steps, calling the Python tool whenever computation is required. -- After completing your reasoning, present the final result enclosed in \\boxed{}. 
-""".strip() - -MAX_NUM_MESSAGES = 16 # messages beyond this will be truncated - - -def create_strands_agent(args, sampling_params): - """Create a strands agent that connects to the SGLang rollout server""" - - # Create an OpenAI model from the SGLang server - model_params = { - "max_tokens": sampling_params["max_new_tokens"], - "temperature": sampling_params["temperature"], - "top_p": sampling_params["top_p"], - } - sglang_server_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1" - logger.info( - f"[Strands Agents] Creating OpenAIModel from SGLang server at {sglang_server_url}" - f" with parameters: {model_params}" - ) - model = OpenAIModel( - client_args={ - "api_key": "EMPTY", - "base_url": sglang_server_url, - "timeout": 300.0, # needed for tool calls - }, - model_id=args.hf_checkpoint.split("/")[-1], - params=model_params, - ) - - # Define the `execute_python_code` tool using camel-ai's subprocess interpreter - @tool - def execute_python_code(code: str) -> str: - r"""Execute a given Python code snippet. - - Args: - code (str): The input Python code to the Code Execution tool call. - - Returns: - str: The text output from the Code Execution tool call. - """ - interpreter = SubprocessInterpreter( - require_confirm=False, - print_stdout=False, - print_stderr=False, - execution_timeout=60.0, - ) - result = interpreter.run(code=code, code_type="python") - logger.info( - f"[Strands Agents] executing Python code: ```python\n{code}\n``` and get execution result: ```python\n{result}\n```" - ) - return result - - # Create the strands agent - agent = Agent( - model=model, - tools=[execute_python_code], - system_prompt=SYSTEM_PROMPT, - callback_handler=None, - ) - - return agent - - -async def run_strands_agent(agent: Agent, prompt: str) -> Sample.Status: - """Run the strands agent with the given prompt and set the sample status.""" - try: - logger.info(f"[Strands Agents] running agent with prompt: {prompt}") - await agent.invoke_async(prompt=prompt) - sample_status = Sample.Status.COMPLETED - except Exception as e: - truncated_conditions = [ - isinstance(e, MaxTokensReachedException), - isinstance(e, ContextWindowOverflowException), - isinstance(e, EventLoopException) - and isinstance(e.original_exception, openai.APIError) - and "context length" in str(e.original_exception).lower(), - ] - if any(truncated_conditions): - sample_status = Sample.Status.TRUNCATED - logger.warning(f"[Strands Agents] sample is TRUNCATED due to {type(e).__name__}: {e}") - else: - sample_status = Sample.Status.ABORTED - logger.error(f"[Strands Agents] sample is ABORTED due to {type(e).__name__}: {e}") - - return sample_status - - -def get_trajectory(agent: Agent) -> list[dict]: - """Get the chat template-compatible trajectory from strands agent's messages.""" - openai_model: OpenAIModel = agent.model - trajectory = openai_model.format_request_messages(messages=agent.messages, system_prompt=agent.system_prompt) - for message in trajectory: - if "content" in message and isinstance(message["content"], list): - if len(message["content"]) > 0 and "text" in message["content"][0]: - message["content"] = message["content"][0]["text"] - else: - message["content"] = "" - return trajectory - - -async def generate(args, sample: Sample, sampling_params) -> Sample: - """Generate function using strands-agents as agent scaffolding""" - assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment." 
- - state = GenerateState(args) - - # Create strands agent - agent = create_strands_agent(args, sampling_params) - - # Run the strands agent - prompt_text = sample.prompt if isinstance(sample.prompt, str) else sample.prompt[0]["content"] - sample.status = await run_strands_agent(agent, prompt_text) - - # Early return if sample is aborted - if sample.status == Sample.Status.ABORTED: - agent.cleanup() - return sample - - # Get the trajectory from the agent and further truncate if necessary - trajectory = get_trajectory(agent) - if len(trajectory) > MAX_NUM_MESSAGES: - logger.warning( - f"[Strands Agents] sample is TRUNCATED due to number of messages (={len(trajectory)}) exceeding limit (={MAX_NUM_MESSAGES})" - ) - # This post-processing is not optimal but just for simplicity - # We should implement a hook in strands-agents to handle this truncation - trajectory = trajectory[:MAX_NUM_MESSAGES] - sample.status = Sample.Status.TRUNCATED - - # Get the initial prompt (system + user message) - initial_prompt_messages = [msg for msg in trajectory if msg["role"] in ["system", "user"]] - assert len(initial_prompt_messages) == 2, "Initial prompt messages must be exactly 2 for single-turn conversations" - prompt_text = state.tokenizer.apply_chat_template( - initial_prompt_messages, - tokenize=False, - add_generation_prompt=True, # Add generation prompt for the assistant - ) - prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"] - - # Build (re-tokenize) the response incrementally - response_token_ids = [] - loss_masks = [] - response_text = "" - - # Start with the initial prompt messages for progressive chat template application - current_messages = list(initial_prompt_messages) - prev_token_count = len(prompt_tokens_ids) - - # Iterate through remaining messages (assistant and tool messages) - for message in trajectory[len(initial_prompt_messages) :]: - # Add this message to the conversation - current_messages.append(message) - - # Apply chat template and tokenize up to this point - current_text = state.tokenizer.apply_chat_template( - current_messages, tokenize=False, add_generation_prompt=False - ) - current_token_ids = state.tokenizer(current_text, add_special_tokens=False)["input_ids"] - - # Calculate how many new tokens this message added - new_token_count = len(current_token_ids) - message_token_length = new_token_count - prev_token_count - - # Extract the new tokens for this message - message_tokens = current_token_ids[prev_token_count:] - assert len(message_tokens) == message_token_length, "Message tokens length mismatch" - response_token_ids.extend(message_tokens) - - # Align message tokens with loss masks - if message["role"] == "assistant": - # We train on assistant messages - loss_masks.extend([1] * message_token_length) - else: - # We don't train on tool messages - loss_masks.extend([0] * message_token_length) - - prev_token_count = new_token_count - - # Extract the response text (everything after the initial prompt) - full_conversation_text = state.tokenizer.apply_chat_template( - trajectory, tokenize=False, add_generation_prompt=False - ) - response_text = full_conversation_text[len(prompt_text) :] - - # Set sample attributes and some debug information - sample.tokens = prompt_tokens_ids + response_token_ids - sample.response_length = len(response_token_ids) - sample.response = response_text - sample.loss_mask = loss_masks - # Store tool call count for reward calculation - sample.tool_call_count = [message["role"] == "tool" for message in 
trajectory].count(True) - - # Log to wandb if available - if wandb.run is not None: - wandb.log( - { - "debug/response_length": sample.response_length, - "debug/available_tools": len(agent.tool_names), - "debug/tool_calls": sample.tool_call_count, - "debug/num_messages": len(trajectory), - "debug/truncated": sample.status == Sample.Status.TRUNCATED, - } - ) - - agent.cleanup() - return sample - - -async def reward_func(args, sample, **kwargs): - """Tool call reward function using math_dapo as primary reward model""" - if not isinstance(sample, Sample): - raise TypeError("Sample must be an instance of Sample class.") - - # Extract information from sample - solution_str = sample.response - ground_truth = sample.label if sample.label is not None else "" - tool_call_count = getattr(sample, "tool_call_count", 0) - - # Accept both Answer: ... and \\boxed{...} answer - result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=False) - result_boxed = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True) - if result["pred"] == "[INVALID]": - result = result_boxed - - # Encourage model to call tools - if result["score"] < 0: - tool_call_reward = (tool_call_count - 2) / 2 * 0.1 - result["score"] = min(-0.6, result["score"] + tool_call_reward) - - if result["pred"] is None: - result["pred"] = "" - - logger.info( - f"[Strands Agents] sample summary: " - f"status={sample.status} | " - f"tool_call_count={sample.tool_call_count} | " - f"response_length={sample.response_length} | " - f"reward={result} | " - f"ground_truth={ground_truth}" - ) - - return result diff --git a/examples/strands_sglang/README.md b/examples/strands_sglang/README.md new file mode 100644 index 000000000..ad310238a --- /dev/null +++ b/examples/strands_sglang/README.md @@ -0,0 +1,70 @@ +# miles x Strands-SGLang + +This example connects `miles` with [`strands-sglang`](https://github.com/horizon-rl/strands-sglang) (SGLang extension for the agentic scaffolding [`strands`](https://github.com/strands-agents/sdk-python)) for agentic RL training. + +## Why `strands-sglang`? + +| Component | Agent Loop | TITO Support | +| ------------------------------------------------------------------ | ----------------------------------- | -------------------------------------- | +| [Strands-Agents](https://github.com/strands-agents/sdk-python) | โœ… Handles agent loop, custom hooks | โŒ text-based, requires retokenization | +| [SGLang](https://github.com/sgl-project/sglang) | โŒ Single generation only | โœ… Native `input_ids` in/out | +| **[strands-sglang](https://github.com/horizon-rl/strands-sglang)** | โœ… Via Strands | โœ… Via SGLang's native API | + +`strands-sglang` bridges the gap by extending `strands` with SGLang's native `/generate` endpoint: + +- Captures exact token IDs during generation (no retokenization drift) +- Automatically tracks `loss_mask` via `token_manager` +- Provides `ToolIterationLimiter` for clean trajectory truncation + +## Install Dependencies + +1. Pull the `radixark/miles:latest` image and enter it +2. Go to miles folder: `cd /root/miles` +3. Install miles: `pip install -e . --no-deps` +4. Go to the example folder: `cd /root/miles/examples/strands_sglang` +5. 
Install other dependencies: `pip install -r requirements.txt` + +> NOTE: `strands-sglang` is under rapid development, so we recommend using the GitHub repo version: `strands-sglang @ git+https://github.com/horizon-rl/strands-sglang.git` + +> NOTE: We use camel-ai's subprocess code interpreter for python code execution, which is NOT a good practice; it's just for convenience of this example. + +## Prepare Model + +```bash +# hf checkpoint +huggingface-cli download Qwen/Qwen3-8B --local-dir /root/models/Qwen/Qwen3-8B + +# mcore checkpoint +cd /root/miles +source scripts/models/qwen3-8B.sh +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/models/Qwen/Qwen3-8B \ + --save /root/models/Qwen/Qwen3-8B_torch_dist +``` + +## Prepare Dataset + +Following [Retool](https://arxiv.org/abs/2504.11536), we use `dapo-math-17k` as training data: + +```python +from datasets import load_dataset +ds = load_dataset("zhuzilin/dapo-math-17k", split="train") +ds.to_json("/root/data/dapo-math-17k.jsonl", orient="records", lines=True) +``` + +and `aime-2024` as eval data: + +```python +from datasets import load_dataset +ds = load_dataset("zhuzilin/aime-2024", split="train") +ds.to_json("/root/data/aime-2024.jsonl", orient="records", lines=True) +``` + +## Run Training + +```bash +cd /root/miles +export WANDB_KEY=$your_wandb_key +bash examples/strands_sglang/strands_qwen3_8b.sh +``` diff --git a/examples/strands_sglang/generate_with_strands.py b/examples/strands_sglang/generate_with_strands.py new file mode 100644 index 000000000..a7cf91819 --- /dev/null +++ b/examples/strands_sglang/generate_with_strands.py @@ -0,0 +1,117 @@ +import logging + +from camel.interpreters import SubprocessInterpreter +from strands import Agent, tool +from strands_sglang import SGLangClient, SGLangModel +from strands_sglang.tool_limiter import ToolIterationLimiter + +from miles.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score +from miles.rollout.sglang_rollout import GenerateState +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT = """ +You are a helpful math-solving assistant with access to the `execute_python_code` tool. + +Guidelines: +- For any numerical or symbolic computation, always use the `execute_python_code` tool rather than performing calculations mentally. +- Break problems into clear steps, calling the Python tool whenever computation is required. +- After completing your reasoning, present the final result enclosed in \\boxed{}. 
+""".strip() + +MAX_TOOL_ITERATIONS = 5 + +_client_cache: dict[str, SGLangClient] = {} + + +def get_client(args) -> SGLangClient: + """Get shared client for connection pooling (like MILES).""" + base_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + if base_url not in _client_cache: + _client_cache[base_url] = SGLangClient.from_miles_args(args) + return _client_cache[base_url] + + +@tool +def execute_python_code(code: str) -> str: + """Execute Python code and return the output.""" + interpreter = SubprocessInterpreter( + require_confirm=False, + print_stdout=False, + print_stderr=False, + execution_timeout=60.0, + ) + result = interpreter.run(code, "python") + logger.info(f"Executing Python code: ```python\n{code}\n``` and get execution result: ```python\n{result}\n```") + return result + + +async def generate(args, sample: Sample, sampling_params) -> Sample: + """Generate with TITO: tokens captured during generation, no retokenization.""" + assert not args.partial_rollout, "Partial rollout not supported." + + state = GenerateState(args) + model = SGLangModel( + tokenizer=state.tokenizer, + client=get_client(args), + model_id=args.hf_checkpoint.split("/")[-1], + params={k: sampling_params[k] for k in ["max_new_tokens", "temperature", "top_p"]}, + ) + + limiter = ToolIterationLimiter(max_iterations=MAX_TOOL_ITERATIONS) + agent = Agent( + model=model, + tools=[execute_python_code], + hooks=[limiter], + callback_handler=None, + system_prompt=SYSTEM_PROMPT, + ) + + prompt = sample.prompt if isinstance(sample.prompt, str) else sample.prompt[0]["content"] + + try: + await agent.invoke_async(prompt) + sample.status = Sample.Status.COMPLETED + except Exception as e: + # Always use TRUNCATED instead of ABORTED because Miles doesn't properly + # handle ABORTED samples in reward processing. 
See: https://github.com/THUDM/slime/issues/200 + sample.status = Sample.Status.TRUNCATED + logger.warning(f"TRUNCATED: {type(e).__name__}: {e}") + + # TITO: extract trajectory from token_manager + tm = model.token_manager + prompt_len = len(tm.segments[0]) # system + user are first segment + sample.tokens = tm.token_ids + sample.loss_mask = tm.loss_mask[prompt_len:] + sample.rollout_log_probs = tm.logprobs[prompt_len:] + sample.response_length = len(sample.tokens) - prompt_len + sample.response = model.tokenizer.decode(sample.tokens[prompt_len:], skip_special_tokens=False) + # Tool iteration and tool call count are different because multiple parallel tool calls count as 1 iteration + sample.tool_iterations = limiter.iteration_count + trajectory = model.format_request_messages(agent.messages, None) + sample.tool_call_count = [message["role"] == "tool" for message in trajectory].count(True) + + model.reset() + agent.cleanup() + return sample + + +async def reward_func(args, sample: Sample, **kwargs): + """Reward function using math_dapo scoring.""" + ground_truth = sample.label or "" + tool_iterations = getattr(sample, "tool_iterations", 0) + + result = math_dapo_compute_score(sample.response, ground_truth, strict_box_verify=False) + if result["pred"] == "[INVALID]": + result = math_dapo_compute_score(sample.response, ground_truth, strict_box_verify=True) + + # Encourage tool use on failures + if result["score"] < 0: + result["score"] = min(-0.6, result["score"] + (tool_iterations - 2) / 2 * 0.1) + + result["pred"] = result["pred"] or "" + logger.info( + f"reward={result['score']:.2f} | status={sample.status.name} | tool_iters={tool_iterations} | tool_calls={getattr(sample, 'tool_call_count', 0)} | tokens={len(sample.tokens)} | resp_len={sample.response_length} | " + ) + return result["score"] diff --git a/examples/strands-agents/requirements.txt b/examples/strands_sglang/requirements.txt similarity index 75% rename from examples/strands-agents/requirements.txt rename to examples/strands_sglang/requirements.txt index 040fa471c..2c838bab3 100644 --- a/examples/strands-agents/requirements.txt +++ b/examples/strands_sglang/requirements.txt @@ -1,3 +1,4 @@ camel-ai strands-agents strands-agents-tools +strands-sglang diff --git a/examples/strands-agents/strands_qwen3_4b.sh b/examples/strands_sglang/strands_qwen3_8b.sh similarity index 73% rename from examples/strands-agents/strands_qwen3_4b.sh rename to examples/strands_sglang/strands_qwen3_8b.sh index 647c8e2f5..7d29eb279 100644 --- a/examples/strands-agents/strands_qwen3_4b.sh +++ b/examples/strands_sglang/strands_qwen3_8b.sh @@ -1,5 +1,8 @@ #!/bin/bash +# Qwen3-8B Training with Strands-SGLang +# Note: 8B model requires ~2x memory of 4B, adjusted settings accordingly + # for rerun the task pkill -9 sglang sleep 3 @@ -23,39 +26,39 @@ else fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -source "/root/miles/scripts/models/qwen3-4B.sh" +source "/root/miles/scripts/models/qwen3-8B.sh" # Generate timestamp suffix for save path TIMESTAMP_SUFFIX=$(date +%Y%m%d_%H%M%S) CKPT_ARGS=( - --hf-checkpoint /root/models/Qwen/Qwen3-4B-Instruct-2507 - --ref-load /root/models/Qwen/Qwen3-4B-Instruct-2507_torch_dist - # --load Qwen3-4B-Instruct-2507_strands_dapo_1129 - --save /root/models/Qwen/Qwen3-4B-Instruct-2507_strands_dapo_${TIMESTAMP_SUFFIX} + --hf-checkpoint /root/models/Qwen/Qwen3-8B + --ref-load /root/models/Qwen/Qwen3-8B_torch_dist + # --load Qwen3-8B_strands_dapo + --save 
/root/models/Qwen/Qwen3-8B_strands_dapo_${TIMESTAMP_SUFFIX} --save-interval 20 - --rotary-base 5000000 + --rotary-base 1000000 ) ROLLOUT_ARGS=( - --prompt-data /root/data/dapo-math-17k/dapo-math-17k.jsonl + --prompt-data /root/data/dapo-math-17k.jsonl --input-key prompt --label-key label --rollout-shuffle - --reward-key score + # --reward-key score --num-rollout 3000 - --rollout-batch-size 32 + --rollout-batch-size 16 # Reduced from 32 for 8B model memory --n-samples-per-prompt 8 - --rollout-max-response-len 8192 + --rollout-max-response-len 16384 --rollout-temperature 1 - --global-batch-size 256 + --global-batch-size 128 # Reduced from 256 for 8B model memory --balance-data ) EVAL_ARGS=( --eval-interval 20 - --eval-prompt-data aime /root/data/aime-2024/aime-2024.jsonl + --eval-prompt-data aime /root/data/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 --eval-top-p 1 @@ -75,7 +78,7 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size - --max-tokens-per-gpu 9216 + --max-tokens-per-gpu 18432 ) GRPO_ARGS=( @@ -100,14 +103,15 @@ OPTIMIZER_ARGS=( WANDB_ARGS=( --use-wandb --wandb-project strands-miles - --wandb-group Qwen3-4B-Instruct-2507-strands-dapo + --wandb-group Qwen3-8B-strands-dapo --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 - --sglang-tool-call-parser qwen # Enable tool call parsing for Strands Agent + --sglang-mem-fraction-static 0.4 + # Note: strands-sglang handles tool parsing internally (HermesToolCallParser) + # No need for --sglang-tool-call-parser ) MISC_ARGS=( @@ -122,8 +126,8 @@ MISC_ARGS=( ) CUSTOM_ARGS=( - --custom-generate-function-path examples.strands-agents.generate_with_strands.generate - --custom-rm-path examples.strands-agents.generate_with_strands.reward_func + --custom-generate-function-path examples.strands_sglang.generate_with_strands.generate + --custom-rm-path examples.strands_sglang.generate_with_strands.reward_func ) # launch the master node of ray in container @@ -155,4 +159,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ ${MISC_ARGS[@]} \ - ${CUSTOM_ARGS[@]} \ No newline at end of file + ${CUSTOM_ARGS[@]} diff --git a/examples/tau-bench/README.md b/examples/tau-bench/README.md index 524157b75..417275c1b 100644 --- a/examples/tau-bench/README.md +++ b/examples/tau-bench/README.md @@ -9,13 +9,13 @@ Use the `zhuzilin/miles:latest` image and initialize the environment required fo cd /root/ git clone https://github.com/radixark/miles.git cd miles -pip install -e . +pip install -e . --no-deps # for tau bench cd /root/ git clone https://github.com/JD-ETH/tau-bench.git cd tau-bench git checkout feature/litellm-retry -pip install -e . +pip install -e . --no-deps ``` Use the following script to generate mock data for miles training. diff --git a/examples/train_infer_mismatch_helper/mis.py b/examples/train_infer_mismatch_helper/mis.py index 19c666bf1..1e5f99275 100644 --- a/examples/train_infer_mismatch_helper/mis.py +++ b/examples/train_infer_mismatch_helper/mis.py @@ -2,7 +2,13 @@ import torch -from miles.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp +from miles.backends.training_utils.parallel import ParallelState + +# NOTE: +# - `compute_mis_weights` is a lightweight, standalone function that is useful to unit-test on CPU. +# - `compute_mis_weights_with_cp` depends on Megatron context-parallel utilities, which are heavy and may not be +# available in minimal environments. 
+# To keep `mis.py` importable for unit tests, we lazily import CP utilities inside `compute_mis_weights_with_cp`. def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: @@ -15,6 +21,26 @@ def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) return result.expand_as(x) if expand else result +def masked_min(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + """Masked min over valid tokens (loss_mask == 1). Returns 0 when mask is empty.""" + mask = loss_mask.bool() + if mask.any(): + result = x[mask].min() + else: + result = torch.tensor(0.0, device=x.device, dtype=x.dtype) + return result.expand_as(x) if expand else result + + +def masked_max(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + """Masked max over valid tokens (loss_mask == 1). Returns 0 when mask is empty.""" + mask = loss_mask.bool() + if mask.any(): + result = x[mask].max() + else: + result = torch.tensor(0.0, device=x.device, dtype=x.dtype) + return result.expand_as(x) if expand else result + + def metrics_append(metrics: dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: """ @@ -60,6 +86,8 @@ def calculate_veto_mask( loss_mask: torch.Tensor, veto_threshold: float | None, metrics: dict[str, list[torch.Tensor]], + *, + metric_prefix: str = "", ) -> torch.Tensor: if veto_threshold is None: return torch.ones_like(log_ratio) @@ -69,16 +97,21 @@ def calculate_veto_mask( has_catastrophic = catastrophic_tokens.any() veto_mask = (~has_catastrophic).float().expand_as(log_ratio) - metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) - metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) + metrics_append(metrics, f"{metric_prefix}catastrophic_token_fraction", catastrophic_tokens.int()) + metrics_append(metrics, f"{metric_prefix}catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) return veto_mask def truncate( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: dict[str, list[torch.Tensor]], upper_bound: float + weights: torch.Tensor, + loss_mask: torch.Tensor, + metrics: dict[str, list[torch.Tensor]], + upper_bound: float, + *, + metric_prefix: str = "", ) -> torch.Tensor: assert upper_bound is not None - metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) + metrics_append(metrics, f"{metric_prefix}truncate_fraction", (weights > upper_bound).int()) return weights.clamp(0, upper_bound) * loss_mask @@ -88,10 +121,12 @@ def clip( metrics: dict[str, list[torch.Tensor]], lower_bound: float, upper_bound: float, + *, + metric_prefix: str = "", ) -> torch.Tensor: assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound - metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) - metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) + metrics_append(metrics, f"{metric_prefix}clip_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, f"{metric_prefix}clip_fraction_high", (weights > upper_bound).int()) return weights.clamp(lower_bound, upper_bound) * loss_mask @@ -101,10 +136,12 @@ def mask( metrics: dict[str, list[torch.Tensor]], lower_bound: float, upper_bound: float, + *, + metric_prefix: str = "", ) -> tuple[torch.Tensor, torch.Tensor]: assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound - metrics_append(metrics, "mask_fraction_low", (weights 
< lower_bound).int()) - metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) + metrics_append(metrics, f"{metric_prefix}mask_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, f"{metric_prefix}mask_fraction_high", (weights > upper_bound).int()) in_range = (weights >= lower_bound) & (weights <= upper_bound) modified_mask = loss_mask * in_range.float() # Zero out padding in weights but preserve values at non-rejected positions @@ -189,11 +226,15 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str metrics_append(metrics, "tis_weight_before_bound", weights) if args.tis_mode == "truncate": - weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound) + weights = truncate(weights, loss_mask, metrics, args.tis_upper_bound, metric_prefix="tis_") elif args.tis_mode == "clip": - weights = clip(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound) + weights = clip( + weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_" + ) elif args.tis_mode == "mask": - weights, modified_mask = mask(weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound) + weights, modified_mask = mask( + weights, loss_mask, metrics, tis_lower_bound, args.tis_upper_bound, metric_prefix="tis_" + ) else: raise ValueError(f"Unsupported tis_mode: {args.tis_mode}") @@ -212,14 +253,18 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str rs_weights = torch.exp(log_ratio_safe_rs) # Apply mask-based rejection sampling - _, modified_mask = mask(rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound) + _, modified_mask = mask( + rs_weights, modified_mask, metrics, rs_lower_bound, rs_upper_bound, metric_prefix="rs_" + ) # Veto on raw per-token ratios (sequence-wise rejection) if args.rs_veto_threshold is not None: - veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics) + veto_mask = calculate_veto_mask( + raw_log_ratio_diff, loss_mask, args.rs_veto_threshold, metrics, metric_prefix="rs_" + ) modified_mask = modified_mask * veto_mask - metrics_append(metrics, "ratio_mean_after_tis", weights) + metrics_append(metrics, "is_ratio_mean_after_tis_rs", weights) weights = weights.detach() modified_mask = modified_mask.detach() @@ -253,6 +298,14 @@ def compute_log_ratio(raw_log_diff: torch.Tensor, mask: torch.Tensor, level: str for w in all_weights: metrics_append(metrics, "batch_norm_factor", torch.ones_like(w)) + # Final weight stats (after optional batch normalization). + # NOTE: These are expanded to token-shape so that the existing mean-reducer can aggregate them. + for w, m in zip(all_weights, loss_masks, strict=False): + m = m.float() + metrics_append(metrics, "is_ratio_mean_final", masked_mean(w, m, expand=True)) + metrics_append(metrics, "is_ratio_min_final", masked_min(w, m, expand=True)) + metrics_append(metrics, "is_ratio_max_final", masked_max(w, m, expand=True)) + return all_weights, all_modified_masks, metrics @@ -265,6 +318,7 @@ def compute_mis_weights_with_cp( loss_masks: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], + parallel_state: ParallelState, **kwargs: Any, ) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]: """ @@ -280,15 +334,18 @@ def compute_mis_weights_with_cp( modified_masks: List of modified response masks with rejection applied (one per sequence). is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors. 
""" + # Lazy import to avoid importing Megatron dependencies when only `compute_mis_weights` is used. + from miles.backends.training_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp + # Gather cp slice from other cp ranks full_rollout_log_probs = [ - all_gather_with_cp(log_prob, total_length, response_length) + all_gather_with_cp(log_prob, total_length, response_length, parallel_state) for log_prob, total_length, response_length in zip( rollout_log_probs, total_lengths, response_lengths, strict=False ) ] full_old_log_probs = [ - all_gather_with_cp(old_log_prob, total_length, response_length) + all_gather_with_cp(old_log_prob, total_length, response_length, parallel_state) for old_log_prob, total_length, response_length in zip( train_log_probs, total_lengths, response_lengths, strict=False ) @@ -308,7 +365,7 @@ def slice_cp_and_concat( ) -> torch.Tensor: values = [ # TODO: A rename of this function? - slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) + slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i], parallel_state) for i in range(len(values)) ] return torch.cat(values, dim=0) @@ -395,3 +452,45 @@ def add_ppl_metrics( rho_squared_seq = torch.exp(2.0 * log_ratio_sum_safe) # (ฮ  ฯ_t)ยฒ chi2_seq = rho_squared_seq - 1.0 metrics_append(metrics, "chi2_seq", chi2_seq) + + +def compute_mis_weights_fsdp( + args, + *, + pg_loss: torch.Tensor, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + **kwargs: Any, +) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]: + """Compute masked importance sampling weights for FSDP. No context parallelism. + + Args: + args: Arguments containing MIS settings (use_tis, tis_mode, etc.) + pg_loss: Policy gradient loss, flattened tensor [total_tokens] + train_log_probs: Training log probs, list of 1D tensors per sequence + rollout_log_probs: Rollout log probs, list of 1D tensors per sequence + loss_masks: Loss masks, list of 1D tensors per sequence + **kwargs: Additional arguments (cp_rank, cp_size, etc.) 
for compatibility + + Returns: + pg_loss: Policy gradient loss with IS weights applied + modified_masks: Modified loss masks after rejection sampling + mis_metrics: Metrics dict with flattened tensors + """ + is_weights, modified_masks, is_metrics = compute_mis_weights( + args=args, + train_log_probs=train_log_probs, + rollout_log_probs=rollout_log_probs, + loss_masks=loss_masks, + ) + + result_metrics = {} + if is_weights is not None: + is_weights_flat = torch.cat(is_weights, dim=0) + pg_loss = pg_loss * is_weights_flat + + for key, values in is_metrics.items(): + result_metrics[f"mis_{key}"] = torch.cat(values, dim=0) + + return pg_loss, modified_masks, result_metrics diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh new file mode 100644 index 000000000..df3848038 --- /dev/null +++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + + + + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + + + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"} +LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + --load /root/Qwen3-4B + --ref-load /root/Qwen3-4B +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --balance-data + --rm-type deepscaler + --num-rollout 100 + --rollout-batch-size 8 + --n-samples-per-prompt 8 + --rollout-max-response-len 4096 + --rollout-temperature 0.8 + --global-batch-size 64 +) + +GRPO_ARGS=( + --use-kl-loss + --advantage-estimator grpo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project miles-dev-mcore-fsdp + --wandb-group qwen3-4B-fsdp-1130-ref + --wandb-key ${WANDB_API_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.75 + --sglang-decode-log-interval 1000 + --sglang-chunked-prefill-size 4096 + --sglang-attention-backend fa3 +) + +TRAIN_BACKEND_ARGS=( + --train-backend fsdp + --update-weight-buffer-size 536870912 + --gradient-checkpointing + --attn-implementation flash_attention_3 + --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +MISC_ARGS=( + --actor-num-nodes 1 + --actor-num-gpus-per-node 8 + --colocate + --use-fault-tolerance + --dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details + # --fsdp-cpu-offload +) + +CUSTOM_ARGS=( + --custom-config-path examples/train_infer_mismatch_helper/mis.yaml + --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_fsdp +) + +# launch the master node of ray in container - 8 GPUs 
for training +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats + + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${TRAIN_BACKEND_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} + + diff --git a/examples/true_on_policy/README.md b/examples/true_on_policy/README.md index 918f5f3a8..620564d41 100644 --- a/examples/true_on_policy/README.md +++ b/examples/true_on_policy/README.md @@ -1,6 +1,6 @@ # True On-Policy between Training and Inference -True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. +True on-policy ensures that the log probs generated by inference engine (SGLang) is strictly equal to the one generated by the training Engine. Here's our [blog](https://lmsys.org/blog/2025-12-03-miles-fsdp/) for more details. ## Examples @@ -17,7 +17,7 @@ python examples/true_on_policy/run_simple.py This script contains more features for various use cases, and one flag is about the true on policy feature. ```bash -python scripts/run_qwen3_4b_fsdp.py --true-on-policy +python scripts/run_qwen3_4b.py --train-backend fsdp --true-on-policy ``` In order to quickly see the curve, you may use `--mode debug_minimal`, which will skip evaluation and run generation with a very short output sequence length. Since true on policy is unrelated to OSL or answer correctness, this can be used for quick experiments. @@ -45,7 +45,7 @@ Detailed reproduction refers to [this](https://gist.github.com/fzyzcjy/46f9fc096 ## How it is Implemented -The core idea is to make each and every operation in training and inference be bitwise equal. The main code is implemented in [#566](https://github.com/radixark/miles/pull/566) and [SGLang#12058](https://github.com/sgl-project/sglang/pull/12058). +The core idea is to make each and every operation in training and inference be bitwise equal. The main code is implemented in [#566](https://github.com/THUDM/slime/pull/566) and [SGLang#12058](https://github.com/sgl-project/sglang/pull/12058). Briefly speaking, we handled the following components to make them aligned: @@ -53,7 +53,7 @@ Briefly speaking, we handled the following components to make them aligned: * GEMM: We use [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) for fast matrix multiplication while preserving true-on-policy, thanks to its algorithm to pick things like tensor core instructions ([SGLang#12142](https://github.com/sgl-project/sglang/pull/12142)). * Batch invariant kernels: This is a prerequisite for true on-policy, and we use [the ones](https://github.com/thinking-machines-lab/batch_invariant_ops) from the Thinking Machines Lab. * Torch compile: We also utilize [`torch.compile`](https://docs.pytorch.org/docs/stable/generated/torch.compile.html) to speed up by avoiding many tiny kernels. -* We align numeric operation details between the two systems for simplicity, such as op dtype, detailed kernels, etc. Some operations can also be compiled to speedup ([#603](https://github.com/radixark/miles/pull/603), [SGLang#12161](https://github.com/sgl-project/sglang/pull/12161)). 
+* We align numeric operation details between the two systems for simplicity, such as op dtype, detailed kernels, etc. Some operations can also be compiled to speedup ([#603](https://github.com/THUDM/slime/pull/603), [SGLang#12161](https://github.com/sgl-project/sglang/pull/12161)). In order to more easily align the two parts, we use SGLang's [dumper](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/debug_utils/dumper.py) tool for quick comparisons. (Need [#12622](https://github.com/sgl-project/sglang/pull/12622) and [#12623](https://github.com/sgl-project/sglang/pull/12623) for most convenience.) diff --git a/examples/true_on_policy/run_simple.py b/examples/true_on_policy/run_simple.py index 1b472b806..7e317195d 100644 --- a/examples/true_on_policy/run_simple.py +++ b/examples/true_on_policy/run_simple.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 1e3e5b3ae..536fdf970 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -1,34 +1,38 @@ import logging import os +import random from argparse import Namespace -from itertools import accumulate import ray import torch import torch.distributed as dist -import torch.nn.functional as F -from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params +from ring_flash_attn import update_ring_flash_attn_params from tqdm import tqdm from transformers import AutoConfig from miles.ray.train_actor import TrainRayActor from miles.utils import train_dump_utils, train_metric_utils from miles.utils.context_utils import with_defer -from miles.utils.data import get_minimum_num_micro_batch_size, process_rollout_data from miles.utils.distributed_utils import get_gloo_group from miles.utils.memory_utils import clear_memory, print_memory -from miles.utils.metric_utils import compute_rollout_step -from miles.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss from miles.utils.processing_utils import load_processor, load_tokenizer from miles.utils.ray_utils import Box from miles.utils.timer import Timer, inverse_timer, timer from miles.utils.tracking_utils import init_tracking -from ...utils import tracking_utils from ...utils.profile_utils import TrainProfiler +from ..training_utils.ci_utils import check_grad_norm +from ..training_utils.data import DataIterator, get_batch, get_data_iterator, get_rollout_data +from ..training_utils.log_utils import ( + aggregate_forward_results, + aggregate_train_losses, + log_rollout_data, + log_train_step, +) +from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function from . 
import checkpoint -from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences from .lr_scheduler import get_lr_scheduler +from .parallel import create_fsdp_parallel_state from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor logger = logging.getLogger(__name__) @@ -51,12 +55,13 @@ class FSDPTrainRayActor(TrainRayActor): def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # type: ignore[override] super().init(args, role, with_ref) - # Setup device mesh for parallelism (handles both CP and non-CP cases) - self._setup_device_mesh() + # Setup ParallelState for both CP and non-CP cases + self.parallel_state = create_fsdp_parallel_state(args) + torch.manual_seed(args.seed) self.train_parallel_config = { - "dp_size": self.dp_size, + "dp_size": self.parallel_state.dp_size, } if self.args.debug_rollout_only: @@ -98,10 +103,10 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty full_state = model.state_dict() - model = apply_fsdp2(model, mesh=self.dp_mesh, cpu_offload=self.fsdp_cpu_offload, args=self.args) + model = apply_fsdp2(model, mesh=self.parallel_state.dp_mesh, cpu_offload=self.fsdp_cpu_offload, args=self.args) model = self._fsdp2_load_full_state_dict( - model, full_state, self.dp_mesh, cpu_offload=True if self.fsdp_cpu_offload else None + model, full_state, self.parallel_state.dp_mesh, cpu_offload=True if self.fsdp_cpu_offload else None ) self.model = model @@ -181,53 +186,6 @@ def _enable_true_on_policy_optimizations(self, args): apply_fsdp_moe_patch() - def _setup_device_mesh(self) -> None: - """Setup device mesh for parallelism (always called, handles both CP and non-CP cases). - - Creates 2D mesh (dp_size, cp_size) for all cases: - - When context_parallel_size > 1: hybrid CP + DP - - When context_parallel_size = 1: pure DP (equivalent to 1D mesh) - - This ensures consistent group management across all parallelism modes. 
- """ - from torch.distributed.device_mesh import init_device_mesh - - world_size = dist.get_world_size() - rank = dist.get_rank() - - # Use context_parallel_size directly (defaults to 1 for pure DP) - self.cp_size = self.args.context_parallel_size - self.dp_size = world_size // self.cp_size - - # Create 2D device mesh: (dp_size, cp_size) - # Ranks laid out in row-major: mesh[dp_idx, cp_idx] = dp_idx * cp_size + cp_idx - # - CP groups: consecutive ranks along dim 1, e.g., [0,1], [2,3], [4,5], [6,7] - # - DP groups: striped ranks along dim 0, e.g., [0,2,4,6], [1,3,5,7] - # When cp_size=1, this degenerates to pure DP - self.mesh = init_device_mesh("cuda", mesh_shape=(self.dp_size, self.cp_size), mesh_dim_names=("dp", "cp")) - - # Extract process groups from mesh - self.dp_group = self.mesh.get_group("dp") # For FSDP gradient sync, metric reduction - self.cp_group = self.mesh.get_group("cp") # For Ring Flash Attention, logit gathering - self.dp_mesh = self.mesh["dp"] # For FSDP - - # Compute local ranks within each dimension - self.dp_rank = rank // self.cp_size - self.cp_rank = rank % self.cp_size - - logger.info( - f"[Rank {rank}] Device mesh (2D): world_size={world_size}, " - f"cp_size={self.cp_size}, dp_size={self.dp_size}" - ) - logger.info(f"[Rank {rank}] Mesh shape: {self.mesh.shape}, " f"dp_rank={self.dp_rank}, cp_rank={self.cp_rank}") - - # Setup Ring Flash Attention with CP group from mesh (only when cp_size > 1) - if self.cp_size > 1: - substitute_hf_flash_attn(self.cp_group, heads_k_stride=1) - logger.info(f"[Rank {rank}] CP initialized via device mesh") - else: - logger.info(f"[Rank {rank}] Pure DP mode (cp_size=1)") - def _get_init_weight_context_manager(self): """Get context manager for model initialization. @@ -317,32 +275,31 @@ def wake_up(self) -> None: dist.barrier(group=get_gloo_group()) print_memory("after wake_up model") - def save_model(self, iteration: int) -> None: + def save_model(self, rollout_id: int, force_sync: bool = False) -> None: """Delegate checkpoint saving to the shared checkpoint utilities.""" if self.args.debug_rollout_only or self.args.save is None: return - checkpoint.save(self, iteration) + assert not self.args.async_save, "FSDPTrainRayActor does not support async_save yet." + checkpoint.save(self, rollout_id) def _compute_log_prob( self, model_tag: str, - packed_batches: list[dict[str, torch.Tensor]], + data_iterator: DataIterator, + num_microbatches: list[int], store_prefix: str = "", ) -> dict[str, list[torch.Tensor]]: - """Compute token log-probabilities for a list of packed batches. + """Compute token log-probabilities using data iterator. Parameters: model_tag: Which parameters to use, e.g. "actor" or "ref". - packed_batches: A list of packed batch dictionaries produced by - `pack_sequences`, each containing at least `tokens` and - `position_ids`; may also include multimodal keys like `pixel_values`. + data_iterator: DataIterator providing micro-batches. + num_microbatches: List of number of microbatches per step. store_prefix: Prefix to use for keys in outputs (e.g., "ref_"). Returns: - A lightweight dictionary keyed by f"{store_prefix}log_probs". The - actual per-sequence results are written in-place into each element of - `packed_batches` under the same key and can be read back by callers. + A lightweight dictionary keyed by f"{store_prefix}log_probs". Note: Uses separate ref model when model_tag == "ref". 
The ref model is @@ -361,26 +318,59 @@ def _compute_log_prob( active_model = self.model try: - rollout_data = {f"{store_prefix}log_probs": []} + forward_data_store = [] + data_iterator.reset() + with timer(f"{store_prefix}log_probs"), torch.no_grad(): - for batch in self.prof.iterate_train_log_probs( - tqdm(packed_batches, desc=f"{store_prefix}log_probs", disable=dist.get_rank() != 0) - ): - model_args = self._get_model_inputs_args(batch) - logits = active_model(**model_args).logits.squeeze(0).float() - log_probs_result, entropy_result = get_logprob_and_entropy_with_cp( - logits=logits, - target_tokens=batch["tokens"], - cp_rank=self.cp_rank, - cp_size=self.cp_size, - cp_group=self.cp_group, - model_input_ids=model_args["input_ids"], - allow_compile=not self.args.true_on_policy_mode, - temperature=self.args.rollout_temperature, - ) - batch[f"{store_prefix}log_probs"] = log_probs_result - if store_prefix == "": - batch["entropy"] = entropy_result + num_steps_per_rollout = len(num_microbatches) + for step_id in range(num_steps_per_rollout): + for _ in self.prof.iterate_train_log_probs( + tqdm( + range(num_microbatches[step_id]), + desc=f"{store_prefix}log_probs", + disable=dist.get_rank() != 0, + ) + ): + forward_only_keys = [ + "tokens", + "loss_masks", + "multimodal_train_inputs", + "total_lengths", + "response_lengths", + "max_seq_lens", + ] + batch = get_batch( + data_iterator, + forward_only_keys, + self.parallel_state, + self.args.data_pad_size_multiplier, + self.args.qkv_format, + get_position_ids=True, + ) + + model_args = self._get_model_inputs_args(batch) + logits = active_model(**model_args).logits.float() + + result = get_log_probs_and_entropy( + logits=logits, + args=self.args, + parallel_state=self.parallel_state, + unconcat_tokens=batch["unconcat_tokens"], + total_lengths=batch["total_lengths"], + response_lengths=batch["response_lengths"], + with_entropy=(store_prefix == ""), + max_seq_lens=batch.get("max_seq_lens", None), + ) + + batch_result = { + f"{store_prefix}log_probs": result["log_probs"], + } + if store_prefix == "" and "entropy" in result: + batch_result["entropy"] = result["entropy"] + forward_data_store.append(batch_result) + + rollout_data = aggregate_forward_results(forward_data_store, data_iterator, self.args, store_prefix) + return rollout_data finally: @@ -393,78 +383,6 @@ def _compute_log_prob( self.model.cuda() dist.barrier(group=get_gloo_group()) - def _packed_data( - self, rollout_data: dict[str, list[torch.Tensor]] - ) -> tuple[list[dict[str, torch.Tensor]], list[int]]: - """Pack variable-length sequences for efficient processing. - - Parameters: - rollout_data: Dictionary of lists containing sequence-level tensors - such as `tokens`, `loss_masks`, `rewards`, `response_lengths`, - `advantages`, `returns`, and optional `rollout_log_probs`. - - Returns: - A pair `(packed_batches, grad_accum)` where `packed_batches` is a list - of packed batch dictionaries and `grad_accum` lists the micro-batch - indices at which to perform optimizer steps. 
- """ - # Pack sequences efficiently - tokens = rollout_data["tokens"] - - packed_batches = [] - mbs_size_list = [] - local_batch_size = self.args.global_batch_size // self.dp_size - assert ( - self.args.global_batch_size % self.dp_size == 0 - ), f"global_batch_size {self.args.global_batch_size} is not divisible by dp_world_size {self.dp_size}" - # Use global_batch_size for splitting when max_tokens_per_gpu is enabled - if self.args.use_dynamic_batch_size: - # In CP mode, CP group shares sequences, so total capacity is max_tokens_per_gpu * cp_size - max_tokens = self.args.max_tokens_per_gpu - if self.cp_size > 1: - max_tokens = max_tokens * self.cp_size - - for i in range(0, len(tokens), local_batch_size): - mbs_size_list.append( - get_minimum_num_micro_batch_size( - [len(t) for t in rollout_data["tokens"][i : i + local_batch_size]], - max_tokens, - ) - ) - num_microbatches = torch.tensor(mbs_size_list, dtype=torch.int, device=torch.cuda.current_device()) - dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=self.dp_group) - num_microbatches = num_microbatches.tolist() - else: - num_microbatches = [self.args.global_batch_size // (self.args.micro_batch_size * self.dp_size)] * ( - len(tokens) // local_batch_size - ) - - start = 0 - for mbs_size in num_microbatches: - end = start + local_batch_size - packed_batches.extend( - pack_sequences( - rollout_data["tokens"][start:end], - rollout_data["loss_masks"][start:end], - rollout_data["rewards"][start:end], - rollout_data["raw_reward"][start:end], - rollout_data["response_lengths"][start:end], - rollout_data["advantages"][start:end], - rollout_data["returns"][start:end], - rollout_log_probs=( - rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None - ), - multimodal_inputs=( - rollout_data["multimodal_inputs"][start:end] if "multimodal_inputs" in rollout_data else None - ), - num_packs=mbs_size, - ) - ) - start = end - grad_accum = list(accumulate(num_microbatches)) - - return packed_batches, grad_accum - def train(self, rollout_id: int, rollout_data_ref: Box) -> None: """Run one training update over a rollout batch. 
@@ -480,7 +398,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: self.wake_up() with inverse_timer("train_wait"), timer("train"): - rollout_data = process_rollout_data(self.args, rollout_data_ref, self.dp_rank, self.dp_size) + rollout_data = get_rollout_data(self.args, rollout_data_ref, self.parallel_state) if self.args.debug_rollout_only: return self._train_core(rollout_id=rollout_id, rollout_data=rollout_data) @@ -492,79 +410,101 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: compute_total_fwd_flops=None, ) - def _log_rollout_data(self, rollout_id: int, rollout_data, packed_batches): - log_dict = {} - if "raw_reward" in rollout_data and dist.get_rank() == 0: - raw_reward_list = rollout_data["raw_reward"] - if raw_reward_list: - log_dict["rollout/raw_reward"] = sum(raw_reward_list) / len(raw_reward_list) - - for metric_key in ["log_probs", "rollout_log_probs", "ref_log_probs", "advantages", "returns"]: - if metric_key not in packed_batches[0]: - continue - val = torch.tensor([0.0], device=torch.cuda.current_device()) - for _mbs_id, batches in enumerate(packed_batches): - unpacked_batches = unpack_sequences(batches) - for unpacked_batch in unpacked_batches: - if isinstance(unpacked_batch[metric_key], torch.Tensor): - loss_masks_tensor = unpacked_batch["loss_masks"].to(device=torch.cuda.current_device()) - metric_tensor = unpacked_batch[metric_key].to(device=torch.cuda.current_device()) - val += (metric_tensor * loss_masks_tensor).sum() / loss_masks_tensor.sum().clamp_min(1) - else: - val += unpacked_batch[metric_key] - dist.all_reduce(val, op=dist.ReduceOp.SUM, group=self.dp_group) - log_dict[f"rollout/{metric_key}"] = ( - val / (self.args.n_samples_per_prompt * self.args.rollout_batch_size) - ).item() - if dist.get_rank() == 0: - logger.info(f"rollout {rollout_id}: {log_dict}") - log_dict["rollout/step"] = compute_rollout_step(self.args, rollout_id) - tracking_utils.log(self.args, log_dict, step_key="rollout/step") - - if self.args.ci_test and self.args.true_on_policy_mode: - assert log_dict["rollout/log_probs"] == log_dict["rollout/rollout_log_probs"], ( - f"CI check failed: true_on_policy_mode is enabled, but log_probs " - f"({log_dict['rollout/log_probs']}) != rollout_log_probs " - f"({log_dict['rollout/rollout_log_probs']})" - ) - def _train_core(self, rollout_id: int, rollout_data) -> None: - if self.args.advantage_estimator in ["grpo", "gspo"]: - rollout_data["advantages"] = rollout_data["returns"] = [ - torch.tensor([rollout_data["rewards"][i]] * rollout_data["response_lengths"][i]) - for i in range(len(rollout_data["rewards"])) - ] - else: - raise NotImplementedError(f"Unsupported advantage_estimator {self.args.advantage_estimator}") - - packed_batches, grad_accum = self._packed_data(rollout_data) + data_iterator, num_microbatches = get_data_iterator(self.args, self.model, self.parallel_state, rollout_data) + data_iterator = data_iterator[0] assert ( - len(grad_accum) > 0 - ), f"Invalid grad_accum {grad_accum} for micro_batch_size {self.args.micro_batch_size} and global_batch_size {self.args.global_batch_size}" + len(num_microbatches) > 0 + ), f"Invalid num_microbatches {num_microbatches} for micro_batch_size {self.args.micro_batch_size} and global_batch_size {self.args.global_batch_size}" if self.ref_model is not None: - self._compute_log_prob("ref", packed_batches, store_prefix="ref_") + ref_results = self._compute_log_prob("ref", data_iterator, num_microbatches, store_prefix="ref_") + rollout_data.update(ref_results) - 
self._compute_log_prob("actor", packed_batches) - self._log_rollout_data(rollout_id, rollout_data, packed_batches) + actor_results = self._compute_log_prob("actor", data_iterator, num_microbatches) + rollout_data.update(actor_results) + + compute_advantages_and_returns(self.args, self.parallel_state, rollout_data) + + log_rollout_data(rollout_id, self.args, rollout_data, self.parallel_state) with timer("actor_train"): - reported_accum: dict[str, list[torch.Tensor]] = {} - self.optimizer.zero_grad(set_to_none=True) - for mbs_id, packed_batch in self.prof.iterate_train_actor( - enumerate(tqdm(packed_batches, desc="actor_train", disable=dist.get_rank() != 0)) - ): - self._train_step( - packed_batch=packed_batch, - reported_accum=reported_accum, - mbs_id=mbs_id, - grad_accum=grad_accum, + data_iterator.reset() + num_steps_per_rollout = len(num_microbatches) + + for step_id in range(num_steps_per_rollout): + self.optimizer.zero_grad(set_to_none=True) + + losses_reduced = [] + for _ in self.prof.iterate_train_actor( + tqdm(range(num_microbatches[step_id]), desc="actor_train", disable=dist.get_rank() != 0) + ): + batch = get_batch( + data_iterator, + [ + "tokens", + "loss_masks", + "multimodal_train_inputs", + "total_lengths", + "response_lengths", + "max_seq_lens", + "log_probs", + "advantages", + "returns", + "ref_log_probs", + "rollout_log_probs", + ], + self.parallel_state, + self.args.data_pad_size_multiplier, + self.args.qkv_format, + get_position_ids=True, + ) + + log_dict = self._train_step( + batch=batch, + step_id=step_id, + num_microbatches=num_microbatches[step_id], + ) + losses_reduced.append(log_dict) + + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad) + grad_norm = grad_norm.full_tensor().item() + + self.optimizer.step() + self.lr_scheduler.step() + + if self.args.ci_test: + check_grad_norm( + args=self.args, + grad_norm=grad_norm, + rollout_id=rollout_id, + step_id=step_id, + role="actor", + rank=self.parallel_state.dp_cp_rank, + ) + + loss_dict = aggregate_train_losses(losses_reduced, self.parallel_state) + + extra_metrics = {} + for param_group_id, param_group in enumerate(self.optimizer.param_groups): + extra_metrics[f"lr-pg_{param_group_id}"] = param_group["lr"] + + log_train_step( + args=self.args, + loss_dict=loss_dict, + grad_norm=grad_norm, + rollout_id=rollout_id, + step_id=step_id, + num_steps_per_rollout=num_steps_per_rollout, + role="actor", + extra_metrics=extra_metrics, ) self.prof.step(rollout_id=rollout_id) - train_dump_utils.save_debug_train_data(self.args, rollout_id=rollout_id, rollout_data=rollout_data) + if self.args.save_debug_train_data is not None: + train_dump_utils.save_debug_train_data(self.args, rollout_id=rollout_id, rollout_data=rollout_data) # Update ref model if needed (copy actor weights to ref) if ( @@ -579,199 +519,23 @@ def _train_core(self, rollout_id: int, rollout_data) -> None: self.ref_model.load_state_dict(actor_state) self.ref_model.cpu() - def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): + def _train_step(self, batch, step_id, num_microbatches): # Prepare model inputs - model_args = self._get_model_inputs_args(packed_batch) - logits = self.model(**model_args).logits.squeeze(0).float() + model_args = self._get_model_inputs_args(batch) + logits = self.model(**model_args).logits.float() - # Compute log probs and entropy (unified for both CP and non-CP modes) - log_probs, entropy_result = get_logprob_and_entropy_with_cp( + loss, normalizer, log_dict = loss_function( + 
args=self.args, + parallel_state=self.parallel_state, + batch=batch, + num_microbatches=num_microbatches, logits=logits, - target_tokens=packed_batch["tokens"], - cp_rank=self.cp_rank, - cp_size=self.cp_size, - cp_group=self.cp_group, - model_input_ids=model_args["input_ids"], - allow_compile=not self.args.true_on_policy_mode, - temperature=self.args.rollout_temperature, + apply_megatron_loss_scaling=False, ) - packed_batch["cur_log_probs"] = log_probs - packed_batch["entropy"] = entropy_result - - unpacked_batches = unpack_sequences(packed_batch) - - old_log_prob_key = "rollout_log_probs" if self.args.use_rollout_logprobs else "log_probs" - missing_old_log_probs = [ - idx - for idx, batch in enumerate(unpacked_batches) - if old_log_prob_key not in batch or not isinstance(batch[old_log_prob_key], torch.Tensor) - ] - if missing_old_log_probs: - raise KeyError( - f"{old_log_prob_key} must be provided as torch.Tensor for all microbatches when " - f"use_rollout_logprobs is set to {self.args.use_rollout_logprobs}. Missing in batches: {missing_old_log_probs}" - ) - old_log_probs = torch.cat([batch[old_log_prob_key] for batch in unpacked_batches], dim=0) - log_probs = torch.cat([batch["cur_log_probs"] for batch in unpacked_batches], dim=0) - advantages = torch.cat([batch["advantages"] for batch in unpacked_batches], dim=0) - loss_masks = [batch["loss_masks"].to(device=log_probs.device) for batch in unpacked_batches] - response_lengths = [batch["response_lengths"] for batch in unpacked_batches] - - advantages = advantages.to(device=log_probs.device) - old_log_probs = old_log_probs.to(device=log_probs.device) - ppo_kl = old_log_probs - log_probs - - if self.args.use_opsm: - opsm_mask, opsm_clipfrac = compute_opsm_mask( - args=self.args, - full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches], - full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches], - advantages=[batch["advantages"] for batch in unpacked_batches], - loss_masks=loss_masks, - ) - - if self.args.advantage_estimator == "gspo": - ppo_kl = compute_gspo_kl( - full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches], - full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches], - local_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches], - loss_masks=loss_masks, - ) - - pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high) - - if self.args.use_opsm: - pg_loss = pg_loss * opsm_mask - - def _has_rollout_log_probs(batch) -> bool: - rollout_tensor = batch.get("rollout_log_probs") - return isinstance(rollout_tensor, torch.Tensor) and rollout_tensor.numel() > 0 - - has_rollout_log_probs = all(_has_rollout_log_probs(batch) for batch in unpacked_batches) - rollout_log_probs = ( - torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0) - if has_rollout_log_probs - else None - ) - - # Apply TIS before sample mean calculation - if self.args.use_tis: - # Apply TIS off-policy correction using importance sampling - assert ( - has_rollout_log_probs and rollout_log_probs is not None - ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS" - - tis = torch.exp(old_log_probs - rollout_log_probs) - ois = (-ppo_kl).exp() - tis_clip = torch.clamp( - tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0) - ) - tis_clipfrac = tis_clip != tis - - pg_loss = pg_loss * tis_clip - - assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet 
implemented" - pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) - pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) - ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) - - # Only compare rollout vs. train log probs when they originate from different stages. - train_rollout_logprob_abs_diff = None - if not self.args.use_rollout_logprobs and rollout_log_probs is not None: - train_rollout_logprob_abs_diff = (old_log_probs - rollout_log_probs).abs() - train_rollout_logprob_abs_diff = sum_of_sample_mean( - train_rollout_logprob_abs_diff, response_lengths, loss_masks - ).detach() - - entropy = torch.cat([batch["entropy"] for batch in unpacked_batches], dim=0) - entropy_loss = sum_of_sample_mean(entropy, response_lengths, loss_masks) - - loss = pg_loss - self.args.entropy_coef * entropy_loss - - if self.args.use_kl_loss: - ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0) - importance_ratio = None - if self.args.use_unbiased_kl: - importance_ratio = torch.exp(log_probs - old_log_probs) - kl = compute_approx_kl( - log_probs, - ref_log_probs, - kl_loss_type=self.args.kl_loss_type, - importance_ratio=importance_ratio, - ) - kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks) - - loss = loss + self.args.kl_loss_coef * kl_loss - - reported = { - "loss": loss.detach(), - "pg_loss": pg_loss.detach(), - "pg_clipfrac": pg_clipfrac.detach(), - "ppo_kl": ppo_kl.detach(), - "entropy_loss": entropy_loss.detach(), - } - - if train_rollout_logprob_abs_diff is not None: - reported["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff - if self.args.use_kl_loss: - reported["kl_loss"] = kl_loss.detach() - - if self.args.use_opsm: - reported["opsm_clipfrac"] = opsm_clipfrac - - if self.args.use_tis and tis is not None: - reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach() - reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach() - - # Scale loss for gradient accumulation - loss = loss * self.dp_size / self.args.global_batch_size loss.backward() - # Accumulate reported metrics (store tensors for later mean) - for k, v in reported.items(): - reported_accum.setdefault(k, []).append(v) - - if (mbs_id + 1) in grad_accum: - # TODO: check if the grad norm is global grad norm. - grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad) - # the grad norm used to be of DTensor - grad_norm = float(grad_norm) - - self.optimizer.step() - # Update learning rate - self.lr_scheduler.step() - self.optimizer.zero_grad(set_to_none=True) - # Aggregate logs - aggregated = {k: torch.stack(v).sum().item() for k, v in reported_accum.items()} - # TODO: change this, this is slow. 
- reduced_aggregated = [None] * self.dp_size - dist.all_gather_object(reduced_aggregated, aggregated, group=self.dp_group) - aggregated = {} - for k in reported_accum.keys(): - aggregated[k] = sum([r[k] for r in reduced_aggregated]) / (self.args.global_batch_size) - reported_accum.clear() - if dist.get_rank() == 0: - log_dict = { - f"train/{k}": (val.item() if torch.is_tensor(val) else val) for k, val in aggregated.items() - } - log_dict["train/grad_norm"] = grad_norm - - # Log learning rate per parameter group; use scheduler's last computed LRs - lr_values = self.lr_scheduler.get_last_lr() - for gid, _group in enumerate(self.optimizer.param_groups): - log_dict[f"train/lr-pg_{gid}"] = lr_values[gid] - - kl_info = "" - if self.args.use_kl_loss and "kl_loss" in aggregated: - kl_info = f", kl_loss: {aggregated['kl_loss']:.4f}, kl_penalty: {aggregated['kl_loss'] * self.args.kl_loss_coef:.4f}" - logger.info(kl_info) - logger.info(f"step {self.global_step}: {log_dict}") - - log_dict["train/step"] = self.global_step - tracking_utils.log(self.args, log_dict, step_key="train/step") - self.global_step += 1 + return log_dict @timer def update_weights(self) -> None: # type: ignore[override] @@ -789,8 +553,19 @@ def update_weights(self) -> None: # type: ignore[override] if num_new_engines > 0: self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock) dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get(self.rollout_manager.clear_num_new_engines.remote()) self.weight_updater.update_weights() + + if self.args.ci_test and len(rollout_engines) > 0: + engine = random.choice(rollout_engines) + engine_version = ray.get(engine.get_weight_version.remote()) + if str(engine_version) != str(self.weight_updater.weight_version): + raise RuntimeError( + f"Weight version mismatch! Engine: {engine_version}, Updater: {self.weight_updater.weight_version}" + ) + clear_memory() def _create_ref_model(self, ref_load_path: str | None): @@ -826,203 +601,40 @@ def _create_ref_model(self, ref_load_path: str | None): full_state = ref_model.state_dict() # Always use CPUOffloadPolicy for reference, let FSDP2 handle the offload. It is faster than model.cpu(). 
- ref_model = apply_fsdp2(ref_model, mesh=self.dp_mesh, cpu_offload=True, args=self.args) - ref_model = self._fsdp2_load_full_state_dict(ref_model, full_state, self.dp_mesh, cpu_offload=True) + ref_model = apply_fsdp2(ref_model, mesh=self.parallel_state.dp_mesh, cpu_offload=True, args=self.args) + ref_model = self._fsdp2_load_full_state_dict( + ref_model, full_state, self.parallel_state.dp_mesh, cpu_offload=True + ) logger.info(f"[Rank {dist.get_rank()}] Reference model created with FSDP2 CPUOffloadPolicy") return ref_model else: raise NotImplementedError(f"Loading from checkpoint file {ref_load_path} not yet implemented") - def _get_model_inputs_args(self, packed_sequence: dict) -> dict: - input_ids = packed_sequence["tokens"].unsqueeze(0) - position_ids = packed_sequence["position_ids"].unsqueeze(0) - if self.cp_size > 1: - - packed_sequence = pad_packed_sequence_with_cp(packed_sequence, self.cp_size) + def _get_model_inputs_args(self, batch: dict) -> dict: + input_ids = batch["tokens"] + position_ids = batch["position_ids"] - if not packed_sequence["cu_seqlens"].is_cuda: - packed_sequence["cu_seqlens"] = packed_sequence["cu_seqlens"].cuda() - cu_seqlens = packed_sequence["cu_seqlens"] - update_ring_flash_attn_params(cu_seqlens, self.cp_group) + if self.parallel_state.cp_size > 1: + if "cu_seqlens" in batch: + cu_seqlens = batch["cu_seqlens"] + if not cu_seqlens.is_cuda: + cu_seqlens = cu_seqlens.cuda() + update_ring_flash_attn_params(cu_seqlens, self.cp_group) - input_ids = torch.chunk(packed_sequence["tokens"].unsqueeze(0), self.cp_size, dim=1)[self.cp_rank] - position_ids = torch.chunk(packed_sequence["position_ids"].unsqueeze(0), self.cp_size, dim=1)[self.cp_rank] + input_ids = torch.chunk(input_ids, self.parallel_state.cp_size, dim=1)[self.parallel_state.cp_rank] + position_ids = torch.chunk(position_ids, self.parallel_state.cp_size, dim=1)[self.parallel_state.cp_rank] model_args = { "input_ids": input_ids, "position_ids": position_ids, "attention_mask": None, } - if packed_sequence.get("multimodal_inputs"): - model_args.update(packed_sequence["multimodal_inputs"]) - return model_args + if batch.get("multimodal_train_inputs"): + model_args.update(batch["multimodal_train_inputs"]) -def selective_log_softmax_raw(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: - """Fused version of the common `log_softmax -> gather` operation. - - The fused version of this operation avoids the (potentially large) memory overhead - of allocating a new tensor to store the full logprobs. - - Parameters: - logits: Tensor of shape [..., V] containing model logits. - input_ids: Tensor of shape [...] of token indices whose log-probabilities are gathered. - - Returns: - Tensor of shape [...] containing the log-probabilities corresponding to `input_ids`. - """ - logprobs = logits.log_softmax(dim=-1) - return torch.gather(logprobs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - - -selective_log_softmax_compiled = torch.compile(dynamic=True)(selective_log_softmax_raw) - - -def gather_log_probs_packed( - shifted_logits: torch.Tensor, - input_ids: torch.Tensor, - allow_compile: bool, - cu_seqlens: torch.Tensor | float | None = None, - temperature: torch.Tensor | None = None, -) -> torch.Tensor: - """Gather next-token log probabilities for packed sequences. - - Parameters: - logits: Model logits of shape [B, T, V] or [T, V]. - input_ids: Token ids of shape [B, T] or [T]. - cu_seqlens: Optional cumulative sequence lengths (unused here). Present - for API compatibility with callers. 
- - Returns: - A tensor of shape [T-1] (or [B, T-1]) with log-probabilities of targets. - """ - # Handle batch dimension - logits should be [batch_size, seq_len, vocab_size] - if shifted_logits.dim() == 3: - # Remove batch dimension for packed sequences - shifted_logits = shifted_logits.squeeze(0) - input_ids = input_ids.squeeze(0) - - if temperature is not None: - shifted_logits = shifted_logits.div(temperature) - - targets = input_ids[1:].to(device=shifted_logits.device) - - # Gather log probs for targets - selective_log_softmax = selective_log_softmax_compiled if allow_compile else selective_log_softmax_raw - return selective_log_softmax(shifted_logits, targets) - - -def get_logprob_and_entropy_with_cp( - logits: torch.Tensor, - target_tokens: torch.Tensor, - cp_rank: int, - cp_size: int, - cp_group, - model_input_ids: torch.Tensor, - allow_compile: bool, - temperature: float | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute log probabilities and entropy in Context Parallel mode. - - Parameters: - logits: Model output logits with shape [chunk_size, vocab_size] - target_tokens: Target tokens with shape [total_seq_len] - cp_rank: Current CP rank - cp_size: CP world size - cp_group: CP communication group - model_input_ids: Model input_ids (used for the last rank) - allow_compile: Whether to allow compilation - temperature: Temperature parameter (optional) - - Returns: - log_probs: Aggregated log probabilities with shape [total_seq_len - 1] - entropy: Aggregated entropy with shape [total_seq_len - 1] - """ - # Fast path for non-CP mode (cp_size=1): avoid unnecessary communication - if cp_size == 1: - shifted_logits = logits[:-1, :] - local_log_probs = gather_log_probs_packed( - shifted_logits, target_tokens, allow_compile=allow_compile, temperature=temperature - ) - log_probs_full = torch.log_softmax(shifted_logits, dim=-1) - probs = torch.softmax(shifted_logits, dim=-1) - entropy = -(probs * log_probs_full).sum(dim=-1) - return local_log_probs, entropy - - chunk_size = logits.shape[0] - tokens_start_index = chunk_size * cp_rank - tokens_end_index = ( - tokens_start_index + chunk_size + 1 if cp_rank < cp_size - 1 else tokens_start_index + chunk_size - ) - - # For the last rank, remove the last logit - logits = logits if cp_rank < cp_size - 1 else logits[:-1, :] - - # Get local tokens for current rank - local_tokens = ( - target_tokens[tokens_start_index:tokens_end_index] if cp_rank < cp_size - 1 else model_input_ids.squeeze(0) - ) - - # Compute local log probs - local_log_probs = gather_log_probs_packed( - logits, local_tokens, allow_compile=allow_compile, temperature=temperature - ) - - # Pad for the last rank - if cp_rank == cp_size - 1: - local_log_probs = F.pad(local_log_probs, (0, chunk_size - local_log_probs.shape[0]), value=0) - - # Compute entropy - shifted_logits = logits[:-1, :] if cp_rank == cp_size - 1 else logits - log_probs_full = torch.log_softmax(shifted_logits, dim=-1) - probs = torch.softmax(shifted_logits, dim=-1) - entropy = -(probs * log_probs_full).sum(dim=-1) - - # Pad entropy for the last rank - if cp_rank == cp_size - 1: - entropy = F.pad(entropy, (0, chunk_size - entropy.shape[0]), value=0) - - # Merge with a single all_gather: stack as [2, chunk_size] - stacked_local = torch.stack([local_log_probs, entropy], dim=0) - gathered_stacked = torch.distributed.nn.functional.all_gather(stacked_local, group=cp_group) - - # Concatenate by effective length (non-last rank=chunk_size, last rank=chunk_size-1) - lp_parts, ent_parts = [], [] - for r in 
range(cp_size): - eff_len = chunk_size if r < cp_size - 1 else max(0, chunk_size - 1) - if eff_len > 0: - lp_parts.append(gathered_stacked[r][0][:eff_len]) - ent_parts.append(gathered_stacked[r][1][:eff_len]) - - log_probs = torch.cat(lp_parts, dim=0) if lp_parts else local_log_probs.new_zeros((0,)) - entropy_result = torch.cat(ent_parts, dim=0) if ent_parts else entropy.new_zeros((0,)) - - # Truncate to global effective length T-1 (packed tokens length is T) - log_probs = log_probs[: len(target_tokens) - 1] - entropy_result = entropy_result[: len(target_tokens) - 1] - - return log_probs, entropy_result - - -def sum_of_sample_mean(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor: - """Compute sum of per-sample means across variable-length responses. - - Parameters: - x: Flat tensor containing concatenated per-token values across samples. - response_lengths: Lengths of each sample's response segment in `x`. - loss_masks: Per-sample masks aligned with `response_lengths`. - - Returns: - A scalar tensor equal to the sum over samples of the mean value within - each sample's response segment. - """ - return sum( - [ - (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) - ] - ) + return model_args @torch.no_grad() diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 3c49a10f8..6daf7f982 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -214,14 +214,15 @@ def save(actor: Any, iteration: int) -> None: state_dict = {"model_state": model_state} dcp.save(state_dict, checkpoint_id=str(model_dir)) - # Save optimizer state - if hasattr(actor, "optimizer") and actor.optimizer is not None: + # Save optimizer state (skip if --no-save-optim is set) + save_optimizer_state = not getattr(actor.args, "no_save_optim", False) + if save_optimizer_state and hasattr(actor, "optimizer") and actor.optimizer is not None: optimizer_state = OptimizerState(actor.model, actor.optimizer) optim_state_dict = {"optim_state": optimizer_state} dcp.save(optim_state_dict, checkpoint_id=str(optimizer_dir)) - # Save LR scheduler state - if hasattr(actor, "lr_scheduler") and actor.lr_scheduler is not None: + # Save LR scheduler state (skip if --no-save-optim is set) + if save_optimizer_state and hasattr(actor, "lr_scheduler") and actor.lr_scheduler is not None: lr_scheduler_state = LRSchedulerState(actor.lr_scheduler) lr_scheduler_state_dict = {"lr_scheduler_state": lr_scheduler_state} dcp.save(lr_scheduler_state_dict, checkpoint_id=str(lr_scheduler_dir)) diff --git a/miles/backends/fsdp_utils/data_packing.py b/miles/backends/fsdp_utils/data_packing.py deleted file mode 100644 index 8318f0c53..000000000 --- a/miles/backends/fsdp_utils/data_packing.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Data packing utilities for FSDP backend to reduce padding overhead.""" - -import math - -import torch -import torch.nn.functional as F - -from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions - - -def pack_sequences( - tokens: list[list[int]], - loss_masks: list[list[int]], - rewards: list[float], - raw_rewards: list, - response_lengths: list[int], - advantages: list[float], - returns: list[float], - rollout_log_probs: list[list[float]] | None = None, - multimodal_inputs: list[dict] | None = None, - max_tokens_per_gpu: int | None = None, - num_packs: int | None = None, -) -> list[dict]: - 
""" - Pack sequences into dense batches with cumulative sequence lengths. - - Args: - tokens: List of token sequences - loss_masks: List of loss masks - rewards: List of rewards per sequence - raw_rewards: List of raw rewards per sequence - response_lengths: List of response lengths per sequence - advantages: List of advantages per sequence - returns: List of returns per sequence - rollout_log_probs: List of rollout log probabilities per sequence - multimodal_inputs: List of dict of multimodal tokens per sequence - max_tokens_per_gpu: Maximum tokens per GPU pack - num_packs: Explicit number of packs to create - - Returns: - List of packed batches with tokens, masks, cu_seqlens, rewards, raw_rewards, response_lengths, advantages, returns - """ - if not tokens: - return [] - - seq_lengths = [len(t) for t in tokens] - - # Determine number of packs and use balanced partitioning - if num_packs: - k_partitions = num_packs - elif max_tokens_per_gpu: - total_tokens = sum(seq_lengths) - k_partitions = max(1, math.ceil(total_tokens / max_tokens_per_gpu)) - else: - k_partitions = 1 - - # Use balanced partitioning for optimal load distribution - partitions = get_seqlen_balanced_partitions( - seq_lengths, k_partitions=k_partitions, equal_size=False # Allow variable sizes for better balance - ) - - # Pack each partition - result = [] - for indices in partitions: - # Build cumulative sequence lengths - cu_seqlens = [0] - flat_tokens = [] - flat_masks = [] - flat_positionids = [] - flat_advantages = [] - flat_returns = [] - flat_rollout_log_probs = [] - - for i in indices: - seq_tokens = tokens[i] - seq_mask = loss_masks[i] - seq_positionids = list(range(len(seq_tokens))) - - flat_tokens.extend(seq_tokens) - flat_positionids.extend(seq_positionids) - flat_masks.extend(seq_mask) - flat_advantages.extend(advantages[i]) - flat_returns.extend(returns[i]) - if rollout_log_probs: - flat_rollout_log_probs.extend(rollout_log_probs[i]) - cu_seqlens.append(cu_seqlens[-1] + len(seq_tokens)) - - packed_batch = { - "tokens": torch.tensor(flat_tokens, dtype=torch.long), - "loss_masks": torch.tensor(flat_masks, dtype=torch.int), - "position_ids": torch.tensor(flat_positionids, dtype=torch.int), - "cu_seqlens": torch.tensor(cu_seqlens, dtype=torch.int32), - "rewards": torch.tensor([rewards[i] for i in indices], dtype=torch.float32), - "raw_reward": [raw_rewards[i] for i in indices], - "response_lengths": [response_lengths[i] for i in indices], - "advantages": torch.tensor(flat_advantages, dtype=torch.float32), - "returns": torch.tensor(flat_returns, dtype=torch.float32), - "rollout_log_probs": torch.tensor( - flat_rollout_log_probs, dtype=torch.float32, device=torch.cuda.current_device() - ), - } - - # Collect and add multimodal inputs for this partition - if multimodal_inputs: - multimodal_data = {} # key -> concatenated tensor - multimodal_num_items = {} # key -> list of item counts per sequence - for i in indices: - for key, mm_tensor in multimodal_inputs[i].items(): - if key not in multimodal_data: - multimodal_data[key] = mm_tensor - multimodal_num_items[key] = [mm_tensor.size(0)] - else: - multimodal_data[key] = torch.cat([multimodal_data[key], mm_tensor], dim=0) - multimodal_num_items[key].append(mm_tensor.size(0)) - packed_batch["multimodal_inputs"] = multimodal_data - packed_batch["multimodal_num_items"] = multimodal_num_items - - result.append(packed_batch) - - return result - - -def unpack_sequences(packed_batch: dict) -> list[dict]: - """ - Unpack sequences from a packed batch. 
- - Args: - packed_batch: Packed batch - - Returns: - List of unpacked batches - """ - - cu_seqlens = packed_batch["cu_seqlens"] - num_sequences = len(cu_seqlens) - 1 - response_lengths = packed_batch["response_lengths"] - multimodal_num_items = packed_batch.get("multimodal_num_items", {}) - - instances = [] - - # Calculate pad_length by counting trailing zeros - tokens = packed_batch["tokens"] - nonzero_indices = (tokens != 0).nonzero(as_tuple=True)[0] - if len(nonzero_indices) > 0: - # Last non-zero index, pad_length is everything after it - pad_length = len(tokens) - nonzero_indices[-1].item() - 1 - else: - pad_length = 0 # No padding if no non-zero tokens (or all zeros) - for i in range(num_sequences): - start_idx = cu_seqlens[i].item() - end_idx = cu_seqlens[i + 1].item() - instance = {} - - # Copy any additional attributes that might exist in the packed batch - for key, value in packed_batch.items(): - if key not in instance: - # Skip multimodal_num_items - it's metadata - if key == "multimodal_num_items": - continue - # Handle multimodal_inputs dict: split each tensor using multimodal_num_items - elif key == "multimodal_inputs" and isinstance(value, dict): - instance[key] = {} - for mm_key, mm_tensor in value.items(): - if mm_key in multimodal_num_items: - num_items_list = multimodal_num_items[mm_key] - start_mm_idx = sum(num_items_list[:i]) - end_mm_idx = start_mm_idx + num_items_list[i] - if num_items_list[i] > 0: - instance[key][mm_key] = mm_tensor[start_mm_idx:end_mm_idx] - # For tensor attributes, we need to slice them appropriately - elif isinstance(value, torch.Tensor): - if key in ["log_probs", "ref_log_probs", "cur_log_probs", "entropy"]: - # These are computed from logits[:-1] so they have length seq_len-1 - instance[key] = value[ - end_idx - 1 - response_lengths[i] - pad_length : end_idx - 1 - pad_length - ] - elif key == "rollout_log_probs": - # rollout_log_probs is packed based on response_lengths, so slice differently - instance[key] = value[sum(response_lengths[:i]) : sum(response_lengths[: i + 1])] - elif key in ["tokens", "position_ids"]: - # For other tensor attributes, try to slice them - if len(value) > start_idx: - instance[key] = value[start_idx:end_idx] - else: - raise ValueError(f"Attribute {key} is not found in the packed batch") - elif key in ["loss_masks", "advantages", "returns"]: - instance[key] = value[sum(response_lengths[:i]) : sum(response_lengths[: i + 1])] - elif isinstance(value, list): - instance[key] = value[i] - else: - raise ValueError(f"Attribute {key} is not found in the packed batch") - - instances.append(instance) - - return instances - - -def pad_packed_sequence_with_cp(packed_sequence: dict, cp_size: int) -> dict: - """Pad packed sequence to make total length divisible by cp_size. - - Args: - packed_sequence: Packed sequence dict containing tokens, position_ids, cu_seqlens, etc. 
- cp_size: Context parallelism world size - - Returns: - Padded packed sequence - """ - seq_length = len(packed_sequence["tokens"]) - # Calculate padding needed: (cp_size - seq_length % cp_size) % cp_size - remainder = seq_length % cp_size - pad_length = (cp_size - remainder) % cp_size - - if pad_length > 0: - packed_sequence["tokens"] = F.pad(packed_sequence["tokens"], (0, pad_length), value=0) - packed_sequence["position_ids"] = F.pad(packed_sequence["position_ids"], (0, pad_length), value=0) - packed_sequence["loss_masks"] = F.pad(packed_sequence["loss_masks"], (0, pad_length), value=0) - packed_sequence["cu_seqlens"][-1] += pad_length - return packed_sequence diff --git a/miles/backends/fsdp_utils/kernels/fused_experts.py b/miles/backends/fsdp_utils/kernels/fused_experts.py index a7970994b..d1c02aae8 100644 --- a/miles/backends/fsdp_utils/kernels/fused_experts.py +++ b/miles/backends/fsdp_utils/kernels/fused_experts.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import torch import triton.language as tl from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( @@ -7,6 +9,8 @@ silu_and_mul, ) +from .fused_moe_triton_backward_kernels import invoke_fused_moe_backward_kernel + class GateUpProjFunction(torch.autograd.Function): @staticmethod @@ -81,14 +85,89 @@ def forward( filter_expert=True, ) - ctx.save_for_backward(hidden_states, w1, topk_weights) + ctx.save_for_backward(hidden_states, w1, topk_weights, topk_ids) + ctx.config = config + ctx.num_tokens = num_tokens + ctx.topk = topk return intermediate_cache1 @staticmethod def backward(ctx, grad_output): - hidden_states, w1, topk_weights = ctx.saved_tensors - return torch.zeros_like(hidden_states), torch.zeros_like(w1), torch.zeros_like(topk_weights), None + """ + Backward pass for GateUpProjFunction using Triton kernels. 
+ + Args: + grad_output: shape (num_tokens * topk, N) + + Returns: + (grad_hidden_states, grad_w1, grad_topk_weights, None) + """ + + hidden_states, w1, topk_weights, topk_ids = ctx.saved_tensors + config = ctx.config + num_tokens = ctx.num_tokens + topk = ctx.topk + + E, N, D_in = w1.shape + CHUNK_SIZE = 64 * 1024 + + # Initialize gradient tensors + grad_hidden_states = torch.zeros_like(hidden_states) + grad_w1 = torch.zeros_like(w1) + # GateUpProj stage doesn't need topk_weights gradient + grad_topk_weights = torch.zeros_like(topk_weights) + + # Process in chunks to match forward pass + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx * topk : end_chunk_idx * topk] + + # Get aligned metadata + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + # Prepare gradient buffer for this chunk + curr_grad_hidden_states = torch.zeros_like(curr_hidden_states) + curr_grad_w1 = torch.zeros_like(w1) + + # Call Triton backward kernel with MUL_ROUTED_WEIGHT=False + # Use chunk of hidden_states to match sorted_token_ids indices + invoke_fused_moe_backward_kernel( + grad_output=curr_grad_output, + input=curr_hidden_states, # Use chunk of hidden_states to match sorted_token_ids + weight=w1, + grad_input=curr_grad_hidden_states, + grad_weight=curr_grad_w1, + grad_topk_weights=None, # Not needed for GateUpProj + topk_weights=curr_topk_weights, + topk_ids=curr_topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_routed_weight=False, + top_k=topk, + config=config, + compute_type=tl.bfloat16, + ) + + # Accumulate gradients + grad_hidden_states[begin_chunk_idx:end_chunk_idx] += curr_grad_hidden_states + grad_w1 += curr_grad_w1 + + return grad_hidden_states, grad_w1, grad_topk_weights, None class SiluAndMulFunction(torch.autograd.Function): @@ -193,15 +272,89 @@ def forward( b_use_tma=False, ) - ctx.save_for_backward(intermediate_cache2, w2, topk_weights) + ctx.save_for_backward(intermediate_cache2, w2, topk_weights, topk_ids) + ctx.config = config + ctx.num_tokens = num_tokens + ctx.topk = topk return intermediate_cache3 @staticmethod def backward(ctx, grad_output): - intermediate_cache2, w2, topk_weights = ctx.saved_tensors + """ + Backward pass for DownProjFunction using Triton kernels. 
+ + Args: + grad_output: shape (num_tokens, topk, hidden_size) + + Returns: + (grad_intermediate_cache2, grad_w2, grad_topk_weights, None) + """ + intermediate_cache2, w2, topk_weights, topk_ids = ctx.saved_tensors + config = ctx.config + num_tokens = ctx.num_tokens + topk = ctx.topk + + E, hidden_size, intermediate_size = w2.shape + CHUNK_SIZE = 64 * 1024 + + # Initialize gradient tensors + grad_intermediate_cache2 = torch.zeros_like(intermediate_cache2) + grad_w2 = torch.zeros_like(w2) + grad_topk_weights = torch.zeros_like(topk_weights) + + # Process in chunks to match forward pass + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_intermediate_cache2 = intermediate_cache2[begin_chunk_idx * topk : end_chunk_idx * topk] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx:end_chunk_idx] + + # Get aligned metadata + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + # Prepare gradient buffers for this chunk + curr_grad_intermediate_cache2 = torch.zeros_like(curr_intermediate_cache2) + curr_grad_w2 = torch.zeros_like(w2) + curr_grad_topk_weights = torch.zeros_like(curr_topk_weights) + + # Call Triton backward kernel with MUL_ROUTED_WEIGHT=True + # Note: Use top_k=1 to match forward pass indexing + invoke_fused_moe_backward_kernel( + grad_output=curr_grad_output, + input=curr_intermediate_cache2, + weight=w2, + grad_input=curr_grad_intermediate_cache2, + grad_weight=curr_grad_w2, + grad_topk_weights=curr_grad_topk_weights, + topk_weights=curr_topk_weights, + topk_ids=curr_topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_routed_weight=True, + top_k=1, + config=config, + compute_type=tl.bfloat16, + ) + + # Accumulate gradients + grad_intermediate_cache2[begin_chunk_idx * topk : end_chunk_idx * topk] = curr_grad_intermediate_cache2 + grad_w2 += curr_grad_w2 + grad_topk_weights[begin_chunk_idx:end_chunk_idx] = curr_grad_topk_weights - return torch.zeros_like(intermediate_cache2), torch.zeros_like(w2), torch.zeros_like(topk_weights), None + return grad_intermediate_cache2, grad_w2, grad_topk_weights, None class MoeSumReduceFunction(torch.autograd.Function): diff --git a/miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py b/miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py new file mode 100644 index 000000000..333347849 --- /dev/null +++ b/miles/backends/fsdp_utils/kernels/fused_moe_triton_backward_kernels.py @@ -0,0 +1,540 @@ +from __future__ import annotations + +from typing import Any + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_moe_backward_input_kernel( + # Pointers to matrices + grad_output_ptr, + weight_ptr, + grad_input_ptr, + grad_topk_weights_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # Strides + stride_gom, + stride_gon, + stride_we, + stride_wn, + stride_wk, + stride_gim, + stride_gik, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Backward kernel for computing grad_input. + + Forward: output = input @ weight.T (optionally multiplied by topk_weights) + Backward: grad_input = grad_output @ weight (optionally multiplied by topk_weights) + + This kernel computes: grad_input[token] = sum_over_N(grad_output[token, n] * weight[expert, n, :]) + If MUL_ROUTED_WEIGHT: grad_input[token] *= topk_weights[token] + + Parallelization: Similar to forward, parallel over M and N dimensions, loop over K. + """ + # Map program ids to blocks (parallel over M and N, similar to forward) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # Check bounds + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Only process if this block is valid + if pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + # Load token information + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + # Get expert ID for this block + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + # Only process if expert is valid + if off_experts != -1: + # Initialize offsets for N dimension (current block) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Load grad_output block: shape (BLOCK_SIZE_M, BLOCK_SIZE_N) + grad_output_ptrs = grad_output_ptr + (offs_token[:, None] * stride_gom + offs_n[None, :] * stride_gon) + grad_out = tl.load( + grad_output_ptrs, + mask=token_mask[:, None] & (offs_n[None, :] < N), + other=0.0, + ) + + # Apply topk_weights to grad_output if needed + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + grad_out = grad_out * moe_weight[:, None] + + # Iterate over K dimension + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Current K offsets + curr_offs_k = k * BLOCK_SIZE_K + offs_k + + # Load weight block: shape (BLOCK_SIZE_N, BLOCK_SIZE_K) + # weight: shape (E, N, K) + weight_ptrs = ( + weight_ptr + + off_experts * stride_we + + offs_n[:, None] * stride_wn + + curr_offs_k[None, :] * stride_wk + ) + w = tl.load( + weight_ptrs, + mask=(offs_n[:, None] < N) & (curr_offs_k[None, :] < K), + other=0.0, + ) + + # Compute contribution: grad_out @ weight + # grad_out: (BLOCK_SIZE_M, BLOCK_SIZE_N) + # w: (BLOCK_SIZE_N, BLOCK_SIZE_K) + # result: (BLOCK_SIZE_M, BLOCK_SIZE_K) + contribution = tl.dot(grad_out, w) + + # Atomic add to grad_input because different N blocks contribute to same K + grad_input_ptrs = grad_input_ptr + ( + (offs_token[:, None] // top_k) * stride_gim + curr_offs_k[None, :] * stride_gik + ) + grad_input_mask = token_mask[:, None] & (curr_offs_k[None, :] < K) + tl.atomic_add(grad_input_ptrs, contribution.to(compute_type), mask=grad_input_mask) + + +@triton.jit +def fused_moe_backward_weight_kernel( + # Pointers to matrices + grad_output_ptr, + input_ptr, + grad_weight_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + 
# Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # Strides + stride_gom, + stride_gon, + stride_im, + stride_ik, + stride_gwe, + stride_gwn, + stride_gwk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Backward kernel for computing grad_weight. + + Forward: output = input @ weight.T (optionally multiplied by topk_weights) + Backward: grad_weight = input.T @ grad_output (optionally multiplied by topk_weights) + + This kernel computes: grad_weight[expert, n, k] = sum_over_tokens(input[token, k] * grad_output[token, n]) + If MUL_ROUTED_WEIGHT: the accumulation is weighted by topk_weights[token] + + Parallelization: Parallel over M and N dimensions with grouping, loop over K. + """ + # Map program ids to blocks (parallel over M and N with grouping, similar to forward and backward_input) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # Check bounds + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Only process if this block is valid + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + # Get expert ID for this M block + expert_id = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + # Only process if expert is valid + if expert_id == -1: + return + + # Load token information for this M block + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + offs_m.to(tl.int64) + offs_token = tl.load( + sorted_token_ids_ptr + offs_token_id, mask=offs_token_id < num_tokens_post_padded, other=num_valid_tokens + ) + offs_token = offs_token.to(tl.int64) + token_mask = (offs_token_id < num_tokens_post_padded) & (offs_token < num_valid_tokens) + + # Clamp offs_token to valid range + offs_token_clamped = tl.where(token_mask, offs_token, 0) + + # Determine input token indices based on MUL_ROUTED_WEIGHT + if MUL_ROUTED_WEIGHT: + input_token_idx = offs_token_clamped + input_mask = token_mask + else: + input_token_idx = offs_token_clamped // top_k + num_input_tokens = num_valid_tokens // top_k + input_mask = token_mask & (input_token_idx < num_input_tokens) + + # Load topk_weights if needed + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token_clamped, mask=token_mask, other=0.0) + + # Current N offset for this program + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + + # Load grad_output for this N block: shape (M, BLOCK_SIZE_N) + # grad_output is always indexed by sorted_token_ids (offs_token_clamped) + # because it has shape (num_tokens * topk, N) + grad_output_ptrs = grad_output_ptr + (offs_token_clamped[:, None] * stride_gom + offs_n[None, :] * stride_gon) + grad_out = tl.load( + grad_output_ptrs, + mask=token_mask[:, None] & (offs_n[None, :] < N), + other=0.0, + ) + + # Apply topk_weights if needed + if MUL_ROUTED_WEIGHT: + grad_out = grad_out * moe_weight[:, None] + + # Zero out padding tokens + token_mask_col = token_mask[:, None] + grad_out = grad_out * token_mask_col + + # Iterate over K blocks and accumulate + for k_block in range(tl.cdiv(K, 
BLOCK_SIZE_K)): + offs_k = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + # Load input for this K block + input_ptrs = input_ptr + (input_token_idx[:, None] * stride_im + offs_k[None, :] * stride_ik) + inp = tl.load( + input_ptrs, + mask=input_mask[:, None] & (offs_k[None, :] < K), + other=0.0, + ) + + # Zero out padding tokens - use input_mask for input, token_mask for grad_output + input_mask_col = input_mask[:, None] + inp = inp * input_mask_col + + # Compute grad_weight contribution: grad_out.T @ inp + grad_w_contribution = tl.dot(grad_out.T, inp) + + # Write back using atomic add + grad_weight_ptrs = ( + grad_weight_ptr + expert_id * stride_gwe + offs_n[:, None] * stride_gwn + offs_k[None, :] * stride_gwk + ) + grad_weight_mask = (offs_n[:, None] < N) & (offs_k[None, :] < K) + tl.atomic_add(grad_weight_ptrs, grad_w_contribution.to(compute_type), mask=grad_weight_mask) + + +@triton.jit +def fused_moe_backward_topk_weights_kernel( + # Pointers to matrices + grad_output_ptr, + input_ptr, + weight_ptr, + grad_topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # Strides + stride_gom, + stride_gon, + stride_im, + stride_ik, + stride_we, + stride_wn, + stride_wk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Backward kernel for computing grad_topk_weights. + + Forward: output = topk_weights * (input @ weight.T) + Backward: grad_topk_weights = sum(grad_output * (input @ weight.T)) + + This kernel computes the gradient of topk_weights by computing the dot product + of grad_output with the forward output before weight multiplication. 
+ """ + # Map program id to token block + pid = tl.program_id(axis=0) + + # Check bounds + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Only process if this block is valid + if pid * BLOCK_SIZE_M < num_tokens_post_padded: + # Load token information + offs_token_id = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load( + sorted_token_ids_ptr + offs_token_id, mask=offs_token_id < num_tokens_post_padded, other=num_valid_tokens + ) + offs_token = offs_token.to(tl.int64) + token_mask = (offs_token_id < num_tokens_post_padded) & (offs_token < num_valid_tokens) + + # Clamp offs_token to valid range for safe pointer arithmetic + offs_token_clamped = tl.where(token_mask, offs_token, 0) + + # Get expert ID for this block + off_experts = tl.load(expert_ids_ptr + pid).to(tl.int64) + + # Only process if expert is valid + if off_experts != -1: + # Initialize offsets + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Accumulator for grad_topk_weights + accumulator = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + # Iterate over N and K dimensions to compute forward output and gradient + for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)): + # Current N offset + curr_offs_n = n * BLOCK_SIZE_N + offs_n + + # Load grad_output block: (M, N) + grad_output_ptrs = grad_output_ptr + ( + offs_token_clamped[:, None] * stride_gom + curr_offs_n[None, :] * stride_gon + ) + grad_out = tl.load( + grad_output_ptrs, + mask=token_mask[:, None] & (curr_offs_n[None, :] < N), + other=0.0, + ) + + # Compute forward output for this N block: input @ weight[:, n, :].T + forward_output_n = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Current K offset + curr_offs_k = k * BLOCK_SIZE_K + offs_k + + # Load input block: (M, K) + input_ptrs = input_ptr + ( + (offs_token_clamped[:, None] // top_k) * stride_im + curr_offs_k[None, :] * stride_ik + ) + inp = tl.load( + input_ptrs, + mask=token_mask[:, None] & (curr_offs_k[None, :] < K), + other=0.0, + ) + + # Load weight block: (N, K) + weight_ptrs = ( + weight_ptr + + off_experts * stride_we + + curr_offs_n[:, None] * stride_wn + + curr_offs_k[None, :] * stride_wk + ) + w = tl.load( + weight_ptrs, + mask=(curr_offs_n[:, None] < N) & (curr_offs_k[None, :] < K), + other=0.0, + ) + + # Accumulate forward output: input @ weight.T + # inp: (M, K), w.T: (K, N) -> (M, N) + forward_output_n += tl.dot(inp, w.T) + + # Compute contribution to grad_topk_weights: sum(grad_out * forward_output) + # Sum over N dimension + accumulator += tl.sum(grad_out * forward_output_n, axis=1) + + # Write back grad_topk_weights using atomic add with clamped token indices + tl.atomic_add(grad_topk_weights_ptr + offs_token_clamped, accumulator.to(compute_type), mask=token_mask) + + +def invoke_fused_moe_backward_kernel( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + grad_input: torch.Tensor, + grad_weight: torch.Tensor, + grad_topk_weights: torch.Tensor | None, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, +) -> None: + """ + Invoke the fused MOE backward kernels to compute gradients. 
+ + Args: + grad_output: Gradient of output, shape (num_tokens * topk, N) or (num_tokens, topk, N) + input: Input tensor, shape (num_tokens, K) + weight: Weight tensor, shape (E, N, K) + grad_input: Output gradient for input, shape (num_tokens, K) + grad_weight: Output gradient for weight, shape (E, N, K) + grad_topk_weights: Output gradient for topk_weights, shape (num_tokens, topk) or None + topk_weights: Top-K routing weights, shape (num_tokens, topk) + topk_ids: Top-K expert IDs, shape (num_tokens, topk) + sorted_token_ids: Sorted token IDs + expert_ids: Expert IDs for each block + num_tokens_post_padded: Number of tokens after padding + mul_routed_weight: Whether to multiply by routing weights + top_k: Number of experts per token + config: Kernel configuration + compute_type: Computation data type + """ + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + # Flatten grad_output if needed + # Before: (num_tokens, topk, hidden_size) + # After: (num_tokens * topk, hidden_size) + if grad_output.ndim == 3: + grad_output = grad_output.reshape(-1, grad_output.shape[-1]) + + E, N, K = weight.shape + + # ===================== Compute grad_input ===================== + def grid_input(META): + return (triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + fused_moe_backward_input_kernel[grid_input]( + grad_output, + weight, + grad_input, + grad_topk_weights if grad_topk_weights is not None else grad_input, # dummy pointer + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + sorted_token_ids.shape[0], + grad_output.shape[0], + grad_output.stride(0), + grad_output.stride(1), + weight.stride(0), + weight.stride(1), + weight.stride(2), + grad_input.stride(0), + grad_input.stride(1), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + **config, + ) + + # ===================== Compute grad_weight ===================== + # Initialize grad_weight to zero + grad_weight.zero_() + + # Use same grid configuration as forward kernel: encode both M and N dimensions + def grid_weight(META): + return (triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + fused_moe_backward_weight_kernel[grid_weight]( + grad_output, + input, + grad_weight, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + sorted_token_ids.shape[0], + grad_output.shape[0], + grad_output.stride(0), + grad_output.stride(1), + input.stride(0), + input.stride(1), + grad_weight.stride(0), + grad_weight.stride(1), + grad_weight.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + **config, + ) + + # ===================== Compute grad_topk_weights (if needed) ===================== + if mul_routed_weight and grad_topk_weights is not None: + + def grid_topk(META): + return (triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]),) + + fused_moe_backward_topk_weights_kernel[grid_topk]( + grad_output, + input, + weight, + grad_topk_weights.view(-1), + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + sorted_token_ids.shape[0], + grad_output.shape[0], + grad_output.stride(0), + grad_output.stride(1), + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + weight.stride(2), + top_k=top_k, + compute_type=compute_type, + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + ) 
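(The three Triton kernels in this new file implement the matrix-calculus identities stated in their docstrings. As a sanity reference only, and not part of the change, the same gradients can be written as a dense per-token loop in plain PyTorch; this sketch ignores the chunking and the top_k=1 re-indexing used for the down projection, and all names below are illustrative.)

import torch


def moe_backward_reference(grad_output, x, weight, topk_weights, topk_ids, mul_routed_weight):
    # Shapes: x (T, K), weight (E, N, K), grad_output (T * top_k, N),
    # topk_weights / topk_ids (T, top_k). Forward per routed pair (t, j):
    #   out = topk_weights[t, j] * (x[t] @ weight[e].T)   if mul_routed_weight
    #   out = x[t] @ weight[e].T                           otherwise
    T, top_k = topk_ids.shape
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(weight)
    grad_tw = torch.zeros_like(topk_weights)
    for t in range(T):
        for j in range(top_k):
            e = int(topk_ids[t, j])
            g = grad_output[t * top_k + j]                  # (N,)
            if mul_routed_weight:
                # d loss / d routing weight = <grad_output, x_t @ W_e.T>
                grad_tw[t, j] = torch.dot(g, weight[e] @ x[t])
                g = g * topk_weights[t, j]
            grad_x[t] += g @ weight[e]                      # grad wrt input, (K,)
            grad_w[e] += torch.outer(g, x[t])               # grad wrt expert weight, (N, K)
    return grad_x, grad_w, grad_tw

(Up to floating-point accumulation order and the atomic adds used above, the kernel outputs should agree with this reference on small shapes.)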
diff --git a/miles/backends/fsdp_utils/parallel.py b/miles/backends/fsdp_utils/parallel.py new file mode 100644 index 000000000..d87aa2b6e --- /dev/null +++ b/miles/backends/fsdp_utils/parallel.py @@ -0,0 +1,58 @@ +import logging +from argparse import Namespace + +import torch.distributed as dist +from ring_flash_attn import substitute_hf_flash_attn +from torch.distributed.device_mesh import init_device_mesh + +from miles.utils.distributed_utils import get_gloo_group + +from ..training_utils.parallel import ParallelState + +logger = logging.getLogger(__name__) + + +def create_fsdp_parallel_state(args: Namespace) -> ParallelState: + """Create a ParallelState instance for FSDP configuration.""" + world_size = dist.get_world_size() + rank = dist.get_rank() + + cp_size = args.context_parallel_size + dp_rank = rank // cp_size + cp_rank = rank % cp_size + + mesh = init_device_mesh("cuda", mesh_shape=(world_size // cp_size, cp_size), mesh_dim_names=("dp", "cp")) + + logger.info( + f"[Rank {rank}] Device mesh (2D): world_size={world_size}, " + f"cp_size={cp_size}, dp_size={world_size // cp_size}" + ) + logger.info(f"[Rank {rank}] Mesh shape: {mesh.shape}, " f"dp_rank={dp_rank}, cp_rank={cp_rank}") + + # Setup Ring Flash Attention with CP group from mesh (only when cp_size > 1) + if cp_size > 1: + substitute_hf_flash_attn(mesh.get_group("cp"), heads_k_stride=1) + logger.info(f"[Rank {rank}] CP initialized via device mesh") + else: + logger.info(f"[Rank {rank}] Pure DP mode (cp_size=1)") + + parallel_state = ParallelState( + dp_rank=dp_rank, + dp_src_rank=dp_rank // world_size, + dp_size=world_size // cp_size, + cp_rank=cp_rank, + cp_size=cp_size, + dp_cp_rank=rank, + dp_cp_size=world_size, + dp_group=mesh.get_group("dp"), + dp_cp_group=dist.group.WORLD, + dp_cp_group_gloo=get_gloo_group(), + cp_group=mesh.get_group("cp"), + tp_size=1, + tp_rank=0, + tp_group=dist.new_group([rank]), + ) + + parallel_state.dp_mesh = mesh["dp"] + + return parallel_state diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index c8dcbd810..d0f2360ab 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -33,6 +33,7 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model + self.weight_version = 0 @abc.abstractmethod def connect_rollout_engines( @@ -43,6 +44,7 @@ def connect_rollout_engines( pass def update_weights(self) -> None: + self.weight_version += 1 bucket = [] bucket_size = 0 for name, param in self.model.state_dict().items(): @@ -71,10 +73,10 @@ def update_weights(self) -> None: def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket) + self.update_bucket_weights(bucket, weight_version=self.weight_version) @abc.abstractmethod - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: pass @@ -114,7 +116,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -162,6 
+164,7 @@ def update_bucket_weights(self, named_tensors) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, + "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) ray.get(ref) @@ -174,10 +177,6 @@ def update_bucket_weights(self, named_tensors) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" - def __init__(self, args: Namespace, model: torch.nn.Module) -> None: - self.args = args - self.model = model - def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -220,7 +219,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -235,6 +234,7 @@ def update_bucket_weights(self, named_tensors) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, + weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/megatron_utils/__init__.py b/miles/backends/megatron_utils/__init__.py index d67804568..a4666fbeb 100644 --- a/miles/backends/megatron_utils/__init__.py +++ b/miles/backends/megatron_utils/__init__.py @@ -20,23 +20,23 @@ def new_init(self, *args, **kwargs): except ImportError: logging.warning("deep_ep is not installed, some functionalities may be limited.") +try: + from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import ( + Qwen3VLMoETextRotaryEmbedding, + Qwen3VLTextRotaryEmbedding, + ) + + def patch_rotary_embedding(cls): + _original_forward = cls.forward + + def _patched_forward(self, *args, packed_seq_params=None, **kwargs): + return _original_forward(self, *args, **kwargs) + + cls.forward = _patched_forward + + patch_rotary_embedding(Qwen3VLTextRotaryEmbedding) + patch_rotary_embedding(Qwen3VLMoETextRotaryEmbedding) +except ImportError: + pass -from .actor import MegatronTrainRayActor -from .arguments import parse_args, set_default_megatron_args, validate_args -from .checkpoint import load_checkpoint, save_checkpoint -from .initialize import init -from .model import initialize_model_and_optimizer - -logging.getLogger().setLevel(logging.WARNING) - - -__all__ = [ - "parse_args", - "validate_args", - "load_checkpoint", - "save_checkpoint", - "set_default_megatron_args", - "MegatronTrainRayActor", - "init", - "initialize_model_and_optimizer", -] +logging.getLogger("megatron").setLevel(logging.WARNING) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index bcbdd1a42..95803c73f 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -1,5 +1,6 @@ import logging import os +import random import socket from argparse import Namespace from contextlib import nullcontext @@ -15,7 +16,6 @@ from miles.ray.train_actor import TrainRayActor from miles.utils import train_dump_utils from miles.utils.context_utils import with_defer -from miles.utils.data import process_rollout_data from miles.utils.distributed_utils import get_gloo_group, init_process_group from miles.utils.memory_utils import clear_memory, print_memory from miles.utils.ray_utils 
import Box @@ -27,12 +27,14 @@ from ...utils.profile_utils import TrainProfiler from ...utils.tensor_backper import TensorBackuper +from ..training_utils.cp_utils import slice_with_cp +from ..training_utils.data import DataIterator, get_data_iterator, get_rollout_data, sync_actor_critic_data +from ..training_utils.log_utils import log_perf_data, log_rollout_data +from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_values from .checkpoint import load_checkpoint -from .cp_utils import slice_log_prob_with_cp, slice_with_cp -from .data import DataIterator, get_data_iterator, log_perf_data, log_rollout_data, sync_actor_critic_data from .initialize import init, is_megatron_main_rank -from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, get_values from .model import forward_only, initialize_model_and_optimizer, save, train +from .parallel import create_megatron_parallel_state from .update_weight.common import named_params_and_buffers from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed from .update_weight.update_weight_from_tensor import UpdateWeightFromTensor @@ -91,6 +93,8 @@ def init( args, role ) + self.parallel_state = create_megatron_parallel_state(model=self.model) + if role == "critic": if self.args.offload_train: self.sleep() @@ -175,62 +179,6 @@ def wake_up(self) -> None: reload_process_groups() print_memory("after wake_up model") - def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: - # Fetch data through ray on CPU, not sure if this will be performance bottleneck. - # Both first pp stage and the last pp stage will receive the data. - rollout_data = process_rollout_data( - self.args, - rollout_data_ref, - mpu.get_data_parallel_rank(with_context_parallel=False), - mpu.get_data_parallel_world_size(with_context_parallel=False), - ) - # TODO: this is ugly, move to somewhere else? 
- # move tokens to GPU in advance - rollout_data["tokens"] = [ - torch.tensor(t, dtype=torch.long, device=torch.cuda.current_device()) for t in rollout_data["tokens"] - ] - rollout_data["loss_masks"] = [ - torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] - ] - - if self.args.qkv_format == "bshd": - # TODO: micro-batch wise dynamic, possibly move to @data.py:get_data_iterator - max_seq_len = max(rollout_data["total_lengths"]) - - # pad to reduce memory fragmentation and maybe make the computation faster - pad_size = mpu.get_tensor_model_parallel_world_size() * self.args.data_pad_size_multiplier - max_seq_len = (max_seq_len + pad_size - 1) // pad_size * pad_size - - rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) - - if "rollout_log_probs" in rollout_data: - rollout_data["rollout_log_probs"] = [ - torch.tensor( - slice_log_prob_with_cp( - log_prob, - total_length, - response_length, - self.args.qkv_format, - rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, - ), - device=torch.cuda.current_device(), - dtype=torch.float32, - ) - for i, (log_prob, total_length, response_length) in enumerate( - zip( - rollout_data["rollout_log_probs"], - rollout_data["total_lengths"], - rollout_data["response_lengths"], - strict=False, - ) - ) - ] - if "rollout_routed_experts" in rollout_data: - rollout_data["rollout_routed_experts"] = [ - torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"] - ] - return rollout_data - def _switch_model(self, target_tag: str) -> None: if target_tag not in self.weights_backuper.backup_tags: raise ValueError(f"Cannot switch to unknown model tag: {target_tag}") @@ -251,8 +199,8 @@ def fill_routing_replay(self, data_iterator, num_microbatches, rollout_data): for iterator in data_iterator: iterator.reset() - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() + tp_rank = self.parallel_state.tp_rank + tp_size = self.parallel_state.tp_size def pad_func(experts, pad): _, num_layers, topk = experts.shape @@ -278,9 +226,9 @@ def pad_func(experts, pad): # TODO: fuse this padding with the following slice_with_cp to reduce memory copy. rollout_routed_experts = [pad_func(r, 1) for r in rollout_routed_experts] # TODO: maybe extract a common process function for here and get_batch? 
- rollout_routed_experts = [slice_with_cp(r, pad_func) for r in rollout_routed_experts] + rollout_routed_experts = [slice_with_cp(r, pad_func, self.parallel_state) for r in rollout_routed_experts] rollout_routed_experts = torch.cat(rollout_routed_experts, dim=0) - pad_size = mpu.get_tensor_model_parallel_world_size() * self.args.data_pad_size_multiplier + pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier pad = (pad_size - rollout_routed_experts.size(0) % pad_size) % pad_size if pad != 0: rollout_routed_experts = pad_func(rollout_routed_experts, pad) @@ -329,6 +277,7 @@ def compute_log_prob( self.model, data_iterator, num_microbatches, + self.parallel_state, store_prefix=store_prefix, ) @@ -337,9 +286,9 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: self.wake_up() with timer("data_preprocess"): - rollout_data = self._get_rollout_data(rollout_data_ref) + rollout_data = get_rollout_data(self.args, rollout_data_ref, self.parallel_state) if self.args.debug_rollout_only: - log_rollout_data(rollout_id, self.args, rollout_data) + log_rollout_data(rollout_id, self.args, rollout_data, self.parallel_state) return if self.role == "critic": @@ -349,7 +298,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None: # Create data iterator for log_probs and train. - data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data) + data_iterator, num_microbatches = get_data_iterator(self.args, self.model, self.parallel_state, rollout_data) rollout_data.update( forward_only( get_values, @@ -357,13 +306,14 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None: self.model, data_iterator, num_microbatches, + self.parallel_state, ) ) if rollout_id >= self.args.num_critic_only_steps: sync_actor_critic_data(self.args, rollout_data, self._actor_critic_groups) - compute_advantages_and_returns(self.args, rollout_data) + compute_advantages_and_returns(self.args, self.parallel_state, rollout_data) self.args.loss_type = "value_loss" train( @@ -373,11 +323,12 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None: self.opt_param_scheduler, data_iterator, num_microbatches, + self.parallel_state, ) def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: # Create data iterator for log_probs and train. - data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data) + data_iterator, num_microbatches = get_data_iterator(self.args, self.model, self.parallel_state, rollout_data) if self.args.use_rollout_routing_replay: self.fill_routing_replay(data_iterator, num_microbatches, rollout_data) @@ -423,12 +374,12 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: # Calculate adv and returns. Need to performed before training (instead of on the fly), # because we may need normalize the whole rollout. 
- compute_advantages_and_returns(self.args, rollout_data) + compute_advantages_and_returns(self.args, self.parallel_state, rollout_data) if self.rollout_data_postprocess is not None: self.rollout_data_postprocess(self.args) - log_rollout_data(rollout_id, self.args, rollout_data) + log_rollout_data(rollout_id, self.args, rollout_data, self.parallel_state) # Train if self.args.use_routing_replay: @@ -441,6 +392,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: self.opt_param_scheduler, data_iterator, num_microbatches, + self.parallel_state, ) self.prof.step(rollout_id=rollout_id) @@ -464,35 +416,74 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: logger.info(f"Updating ref model at rollout_id {rollout_id}") self.weights_backuper.backup("ref") - log_perf_data(rollout_id, self.args) + log_perf_data(rollout_id, self.args, self.parallel_state) @timer - def save_model(self, iteration: int) -> None: + def save_model(self, rollout_id: int, force_sync: bool = False) -> None: if self.args.debug_rollout_only: return - save(iteration, self.model, self.optimizer, self.opt_param_scheduler) + # torch dist may trigger nccl communication during saving. + if self.args.offload_train: + reload_process_groups() + + if self.args.async_save: + from megatron.training.async_utils import maybe_finalize_async_save + + maybe_finalize_async_save(blocking=True) + + save(rollout_id, self.model, self.optimizer, self.opt_param_scheduler) + + if force_sync and self.args.async_save: + maybe_finalize_async_save(blocking=True) + + if self.args.save_hf is not None and self.role == "actor": + from miles.backends.megatron_utils.model import save_hf_model + + save_hf_model(self.args, rollout_id, self.model) + + if self.args.offload_train: + destroy_process_groups() @timer def update_weights(self) -> None: if self.args.debug_train_only or self.args.debug_rollout_only: return - if self.args.offload_train: - reload_process_groups() + if self.args.use_fault_tolerance: + if dist.get_rank() == 0: + ray.get(self.rollout_manager.recover_rollout_engines.remote()) + dist.barrier(group=get_gloo_group()) rollout_engines, rollout_engine_lock, num_new_engines = ray.get( self.rollout_manager.get_rollout_engines_and_lock.remote() ) + + if self.args.offload_train: + reload_process_groups() + + if isinstance(num_new_engines, tuple): + num_new_engines = num_new_engines[0] + if num_new_engines > 0: self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock) dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get(self.rollout_manager.clear_num_new_engines.remote()) with torch_memory_saver.disable() if self.args.offload_train else nullcontext(): print_memory("before update_weights") self.weight_updater.update_weights() print_memory("after update_weights") + if self.args.ci_test and len(rollout_engines) > 0: + engine = random.choice(rollout_engines) + engine_version = ray.get(engine.get_weight_version.remote()) + if str(engine_version) != str(self.weight_updater.weight_version): + raise RuntimeError( + f"Weight version mismatch! 
Engine: {engine_version}, Updater: {self.weight_updater.weight_version}" + ) + if getattr(self.args, "keep_old_actor", False): if self.args.update_weights_interval == 1: logger.info("updating model queue: rollout_actor -> old_actor, actor -> rollout_actor") diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 5d5090116..0eb2bcd44 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -16,6 +16,8 @@ def set_default_megatron_args(args): # placeholders args.seq_length = 4096 args.max_position_embeddings = args.seq_length + # TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug + args.dist_ckpt_save_pre_mcore_014 = True # compatible for megatron if hasattr(args, "rope_type") and args.rope_type is None: args.rope_type = "yarn" if args.multi_latent_attention else "rope" diff --git a/miles/backends/megatron_utils/checkpoint.py b/miles/backends/megatron_utils/checkpoint.py index 6bd77d4a4..87495b0d0 100644 --- a/miles/backends/megatron_utils/checkpoint.py +++ b/miles/backends/megatron_utils/checkpoint.py @@ -7,8 +7,88 @@ from megatron.training.checkpointing import load_checkpoint as _load_checkpoint_megatron from megatron.training.checkpointing import save_checkpoint from megatron.training.global_vars import get_args + from miles.utils import megatron_bridge_utils +try: + # Here we patch out the `validate_non_overlapping_shards_metadata` in both functions + # because it is really slow for large models with many shards. + # TODO: find a less hacky way to do this. + import torch.distributed as dist + import torch.distributed._shard.sharding_spec as shard_spec + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata + from torch.distributed._shard.sharded_tensor.shard import Shard + from torch.distributed._shard.sharded_tensor.utils import _parse_and_validate_remote_device + from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec + + def __post_init__(self): + pass + + EnumerableShardingSpec.__post_init__ = __post_init__ + + @classmethod + def _init_from_local_shards_and_global_metadata( # type: ignore[override] + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + process_group=None, + init_rrefs=False, + sharding_spec=None, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor with local shards and a global + ShardedTensorMetadata built on each rank. + + Warning: This API is experimental and subject to change. 
It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + shards_metadata = sharded_tensor_metadata.shards_metadata + + local_shard_metadatas = [] + + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) + + if current_rank == rank: + local_shard_metadatas.append(shard_metadata) + + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + # done validation, add local_shards + sharded_tensor._local_shards = local_shards + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + ShardedTensor._init_from_local_shards_and_global_metadata = _init_from_local_shards_and_global_metadata + +except ImportError: + pass + logger = logging.getLogger(__name__) __all__ = ["save_checkpoint"] @@ -47,13 +127,15 @@ def _is_megatron_checkpoint(path: str | Path) -> bool: def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str): + assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint" from megatron.bridge import AutoBridge + import miles_plugins.megatron_bridge # noqa: F401 logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})") - bridge = AutoBridge.from_hf_pretrained(load_path, trust_remote_code=True) with megatron_bridge_utils.patch_megatron_model(ddp_model): + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) bridge.load_hf_weights(ddp_model) # Copied from Megatron-core :: load_checkpoint (with simplifications) diff --git a/miles/backends/megatron_utils/ci_utils.py b/miles/backends/megatron_utils/ci_utils.py new file mode 100644 index 000000000..e6ce784ca --- /dev/null +++ b/miles/backends/megatron_utils/ci_utils.py @@ -0,0 +1,84 @@ +"""CI utilities for Megatron backend testing.""" + +import logging +from collections.abc import Sequence + +from megatron.core.distributed import DistributedDataParallel as DDP + +logger = logging.getLogger(__name__) + + +def check_mtp_only_grad(model: Sequence[DDP], step_id: int) -> None: + """Check that only MTP parameters have non-zero gradients. + + This is used for CI testing to verify that when all outputs are truncated, + only the MTP layers receive gradients (since only mtp_loss contributes). + + Args: + model: Sequence of DDP-wrapped model chunks. + step_id: Current step index for logging. + + Raises: + AssertionError: If any non-MTP parameter has a non-zero gradient. 
+ """ + non_mtp_nonzero_grads = [] + mtp_nonzero_grads = [] + + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + # Get the main_grad from the distributed optimizer if available + grad = getattr(param, "main_grad", None) + if grad is None: + grad = param.grad + if grad is None: + continue + + grad_norm = grad.abs().max().item() + is_mtp = ".mtp." in name + + if is_mtp: + if grad_norm > 0: + mtp_nonzero_grads.append((name, grad_norm)) + else: + if grad_norm > 0: + non_mtp_nonzero_grads.append((name, grad_norm)) + + # Log the results + logger.info( + f"[CI MTP Grad Check] Step {step_id}: " + f"MTP params with non-zero grad: {len(mtp_nonzero_grads)}, " + f"non-MTP params with non-zero grad: {len(non_mtp_nonzero_grads)}" + ) + + if non_mtp_nonzero_grads: + # Log the first few non-MTP params with non-zero gradients for debugging + for name, grad_norm in non_mtp_nonzero_grads[:5]: + logger.error(f"[CI MTP Grad Check] Non-MTP param with non-zero grad: {name}, max_grad={grad_norm}") + + assert len(non_mtp_nonzero_grads) == 0, ( + f"Expected all non-MTP parameters to have zero gradients, " + f"but found {len(non_mtp_nonzero_grads)} with non-zero gradients. " + f"First few: {non_mtp_nonzero_grads[:5]}" + ) + + # Also verify that MTP params do have gradients (otherwise the test is not valid) + assert len(mtp_nonzero_grads) > 0, ( + "Expected MTP parameters to have non-zero gradients, but all were zero. " + "This may indicate the MTP loss is not being computed." + ) + + +def check_mtp_loss(mtp_loss: float, max_mtp_loss: float = 1.0) -> None: + """Check that MTP loss is within expected bounds. + + Args: + mtp_loss: The computed MTP loss value. + max_mtp_loss: Maximum allowed MTP loss (default: 1.0). + + Raises: + AssertionError: If MTP loss exceeds the maximum allowed value. + """ + assert mtp_loss < max_mtp_loss, ( + f"MTP loss {mtp_loss} exceeds maximum allowed value {max_mtp_loss}. " + "This may indicate an issue with MTP training." 
+ ) diff --git a/miles/backends/megatron_utils/config_mapping/__init__.py b/miles/backends/megatron_utils/config_mapping/__init__.py deleted file mode 100644 index cc8ebc132..000000000 --- a/miles/backends/megatron_utils/config_mapping/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .registry import mapper_registry, register_mapper - - -def get_mapper(name: str): - return mapper_registry.get_mapper(name) - - -__all__ = [ - "register_mapper", - "mapper_registry", - "get_mapper", -] diff --git a/miles/backends/megatron_utils/config_mapping/predefined_config_mappers.py b/miles/backends/megatron_utils/config_mapping/predefined_config_mappers.py deleted file mode 100644 index 8f092d9ae..000000000 --- a/miles/backends/megatron_utils/config_mapping/predefined_config_mappers.py +++ /dev/null @@ -1,128 +0,0 @@ -from collections import namedtuple -import torch.nn.functional as F -from transformers import PretrainedConfig -from .registry import register_mapper - - -MegatronModelConfig = namedtuple("MegatronModelConfig", ["transformer_config", "gpt_model_args"]) - - -def _get_activation_func(name: str): - if name == "silu": - return F.silu - elif name == "gelu": - return F.gelu - else: - raise ValueError(f"Unsupported activation function: {name}") - - -def _to_n_args(value): - if isinstance(value, list): - return value - return [value] - - -def _map_common_configs(hf_config: PretrainedConfig) -> MegatronModelConfig: - rope_scaling_args = {} - if "rope_scaling" in hf_config and hf_config.rope_scaling is not None: - rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] - return MegatronModelConfig( - transformer_config={ - # Model architecture parameters - "num_layers": hf_config.num_hidden_layers, - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_query_groups": hf_config.num_key_value_heads, - "ffn_hidden_size": hf_config.intermediate_size, - "kv_channels": getattr(hf_config, "head_dim", None), - "layernorm_epsilon": hf_config.rms_norm_eps, - # Activation and normalization - "activation_func": _get_activation_func(hf_config.hidden_act), - "normalization": "RMSNorm", - "gated_linear_unit": True, - }, - gpt_model_args={ - "vocab_size": hf_config.vocab_size, - "rotary_base": hf_config.rope_theta, - "position_embedding_type": "rope", - "untie_embeddings_and_output_weights": not hf_config.tie_word_embeddings, - }, - ) - - -@register_mapper("qwen2") -def qwen2_config_mapper(hf_config: PretrainedConfig) -> MegatronModelConfig: - mapped_config = _map_common_configs(hf_config) - mapped_config.transformer_config.update( - { - "add_bias_linear": False, - "add_qkv_bias": hf_config.attention_bias, - } - ) - - return mapped_config - - -@register_mapper("qwen3") -def qwen3_config_mapper(hf_config: PretrainedConfig) -> MegatronModelConfig: - mapped_config = _map_common_configs(hf_config) - mapped_config.transformer_config.update( - { - "add_bias_linear": False, - "add_qkv_bias": hf_config.attention_bias, - "qk_layernorm": True, - } - ) - - return mapped_config - - -@register_mapper("qwen3_moe") -def qwen3_moe_config_mapper(hf_config: PretrainedConfig) -> MegatronModelConfig: - mapped_config = _map_common_configs(hf_config) - mapped_config.transformer_config.update( - { - "add_bias_linear": False, - "add_qkv_bias": hf_config.attention_bias, - "moe_ffn_hidden_size": hf_config.moe_intermediate_size, - "moe_router_topk": hf_config.num_experts_per_tok, - "num_moe_experts": hf_config.num_experts, - "moe_aux_loss_coeff": 
_to_n_args(hf_config.router_aux_loss_coef), - "moe_router_load_balancing_type": _to_n_args("none"), # turn off aux_loss as it hurts perf in RL - "moe_router_score_function": "softmax", - "moe_router_pre_softmax": False, - "qk_layernorm": True, - } - ) - - return mapped_config - - -@register_mapper("glm4_moe") -def glm4_moe_config_mapper(hf_config: PretrainedConfig) -> MegatronModelConfig: - moe_layer_freq = [1] * hf_config.num_hidden_layers - for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)): - moe_layer_freq[i] = 0 - - mapped_config = _map_common_configs(hf_config) - mapped_config.transformer_config.update( - { - "add_bias_linear": False, - "qk_layernorm": hf_config.use_qk_norm, - "add_qkv_bias": hf_config.attention_bias, - "moe_ffn_hidden_size": hf_config.moe_intermediate_size, - "moe_router_topk": hf_config.num_experts_per_tok, - "moe_router_topk_scaling_factor": hf_config.routed_scaling_factor, - "moe_router_dtype": "fp32", - "num_moe_experts": hf_config.num_experts, - "moe_router_enable_expert_bias": True, - "moe_layer_freq": moe_layer_freq, - "moe_router_bias_update_rate": 0.0, - "moe_aux_loss_coeff": _to_n_args(hf_config.router_aux_loss_coef), - "moe_router_load_balancing_type": _to_n_args("seq_aux_loss"), - "moe_router_score_function": "sigmoid", - "rotary_percent": hf_config.partial_rotary_factor, - } - ) - - return mapped_config diff --git a/miles/backends/megatron_utils/config_mapping/registry.py b/miles/backends/megatron_utils/config_mapping/registry.py deleted file mode 100644 index ebc2677f2..000000000 --- a/miles/backends/megatron_utils/config_mapping/registry.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from collections.abc import Callable - -logger = logging.getLogger(__name__) - - -class MapperRegistry: - """ - Registry for config mappers. - """ - - def __init__(self): - self._mappers: dict[str, Callable] = {} - - def register(self, model_types: list[str], mapper_func: Callable): - if not callable(mapper_func): - raise TypeError(f"Mapper for {model_types} must be callable") - - for name in model_types: - if name in self._mappers: - logger.warning(f"Mapper for {name} is being overridden") - self._mappers[name] = mapper_func - logger.info(f"Registered config mapper for model type: {name}") - - def get_mapper(self, name: str) -> Callable: - """ - Get the mapper by model_type. - """ - if name not in self._mappers: - raise ValueError(f"Mapper for {name} is not registered.") - return self._mappers[name] - - def list_registered_mappers(self) -> list[str]: - return list(self._mappers.keys()) - - -# Global registry instance -mapper_registry = MapperRegistry() - - -def register_mapper(*args): - """ - Decorator: register config mapper. - - Args: suppotred model_types. 
- """ - - def decorator(func: Callable): - mapper_registry.register( - model_types=list(args), - mapper_func=func, - ) - return func - - return decorator diff --git a/miles/backends/megatron_utils/initialize.py b/miles/backends/megatron_utils/initialize.py index e9f062c11..cbe981632 100644 --- a/miles/backends/megatron_utils/initialize.py +++ b/miles/backends/megatron_utils/initialize.py @@ -4,6 +4,7 @@ import numpy as np import torch from megatron.core import mpu, tensor_parallel +from megatron.core.config import set_experimental_flag from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator from megatron.training.global_vars import _build_tokenizer, set_args @@ -54,6 +55,10 @@ def _initialize_distributed(args, get_embedding_ranks=None, get_position_embeddi def init(args): set_args(args) + if args.enable_experimental: + logger.info("Enable megatron experimental") + set_experimental_flag(True) + # Pytorch distributed. _initialize_distributed(args) diff --git a/miles/backends/megatron_utils/kernels/int4_qat/fake_int4_quant_cuda.cu b/miles/backends/megatron_utils/kernels/int4_qat/fake_int4_quant_cuda.cu new file mode 100644 index 000000000..a6e955490 --- /dev/null +++ b/miles/backends/megatron_utils/kernels/int4_qat/fake_int4_quant_cuda.cu @@ -0,0 +1,368 @@ +#include +#include + +#define FINAL_MASK 0xFFFFFFFF + +__device__ __host__ __forceinline__ +int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +__device__ __forceinline__ +float warpReduceMax(float val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = fmaxf(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + + +__device__ __forceinline__ +float warpReduceMin(float val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = fminf(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +// almost all int4 use blocksize = [1, 32] +template +__global__ +void int4_quant_1x32_kernel( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ out, + scalar_t* out_scale, + scalar_t* out_zero, + const int M, const int N, + const int stride_xm, const int stride_xn, + const int stride_om, const int stride_on, + const int stride_osm, const int stride_osn, + const int stride_ozm, const int stride_ozn, + bool sym +) { + constexpr int WARPS_PER_BLOCK = 8; + const int needed_warps = ceil_div(N, 32); + + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane_id = tid & 0x1F; + constexpr float SYM_CONS = 1.0f / 7.0f; + constexpr float ASYM_CONS = 1.0f / 15.0f; + + const int row = blockIdx.x; + + for (int item = warp_id; item < needed_warps; item += WARPS_PER_BLOCK) { + const int col = item * 32 + lane_id; + float val = 0.0f; + + if (col < N) { + val = static_cast(x[row * stride_xm + col * stride_xn]); + } + + float scale = 0.0f; + float zero = 0.0f; + + if (sym) { + float abs_val = fabsf(val); + + float block_max = warpReduceMax(abs_val); + + scale = fmaxf(block_max * SYM_CONS, 1e-5f); + + val = rintf(val / scale); + } else { + float block_min = warpReduceMin(val); + float block_max = warpReduceMax(val); + + scale = fmaxf((block_max - block_min) * ASYM_CONS, 1e-5f); + zero = fminf(fmaxf(-rintf(block_min / scale), 0.0f), 15.0f); + + val = rintf(val / scale) + zero; + } + + if (col < N) { + out[row * stride_om + col * stride_on] = static_cast(val); + out_scale[row * stride_osm + item * stride_osn] = static_cast(scale); + if(!sym) { + out_zero[row * stride_ozm + item * stride_ozn] = static_cast(zero); + } + } + } +} + +// for some 
transpose case, blocksize = [32, 1] +template +__global__ +void int4_quant_32x1_kernel( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ out, + scalar_t* out_scale, + scalar_t* out_zero, + const int M, const int N, + const int stride_xm, const int stride_xn, + const int stride_om, const int stride_on, + const int stride_osm, const int stride_osn, + const int stride_ozm, const int stride_ozn, + bool sym +) { + constexpr int WARPS_PER_BLOCK = 8; + const int start_row = blockIdx.x * 32; + const int end_row = min((blockIdx.x + 1) * 32, M); + + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane_id = tid & 0x1F; + constexpr float SYM_CONS = 1.0f / 7.0f; + constexpr float ASYM_CONS = 1.0f / 15.0f; + + for (int item = warp_id; item < N; item += WARPS_PER_BLOCK) { + const int col = item; + const int row = start_row + lane_id; + + float val = 0.0f; + + if (row < end_row) { + val = static_cast(x[row * stride_xm + col * stride_xn]); + } + + float scale = 0.0f; + float zero = 0.0f; + + if (sym) { + float abs_val = fabsf(val); + + float block_max = warpReduceMax(abs_val); + + scale = fmaxf(block_max * SYM_CONS, 1e-5f); + + val = rintf(val / scale); + } else { + float block_min = warpReduceMin(val); + float block_max = warpReduceMax(val); + + scale = fmaxf((block_max - block_min) * ASYM_CONS, 1e-5f); + zero = fminf(fmaxf(-rintf(block_min / scale), 0.0f), 15.0f); + + val = rintf(val / scale) + zero; + } + + if (row < end_row) { + out[row * stride_om + col * stride_on] = static_cast(val); + out_scale[blockIdx.x * stride_osm + item * stride_osn] = static_cast(scale); + if (!sym) { + out_zero[blockIdx.x * stride_ozm + item * stride_ozn] = static_cast(zero); + } + } + } +} + +template +__global__ void int4_quant_common_kernel( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ out, + scalar_t* out_scale, + scalar_t* out_zero, + const int M, const int N, + const int stride_xm, const int stride_xn, + const int stride_om, const int stride_on, + const int stride_osm, const int stride_osn, + const int stride_ozm, const int stride_ozn, + const int BLOCK_M, const int BLOCK_N, + bool sym +) { + const int start_row = blockIdx.x * BLOCK_M; + const int WARPS_PER_BLOCK = blockDim.x >> 5; + + const int warp_id = threadIdx.x >> 5; + const int lane_id = threadIdx.x & 0x1F; + constexpr float SYM_CONS = 1.0f / 7.0f; + constexpr float ASYM_CONS = 1.0f / 15.0f; + constexpr int WARP_SIZE = 32; + + const int needed_warps = ceil_div(N, BLOCK_N); + const int iters = ceil_div(BLOCK_M * BLOCK_N, 32); + int warp_rows = 1; + + if (BLOCK_N <= WARP_SIZE) { + warp_rows = WARP_SIZE / BLOCK_N; + } + + for (int item = warp_id; item < needed_warps; item += WARPS_PER_BLOCK) { + float local_max = -INFINITY; + float local_min = INFINITY; + + float val = 0.0f; + float scale, zero = 0.0f; + + const int row_off = lane_id / BLOCK_N; + const int col_off = lane_id % BLOCK_N; + int row, col = 0; + + for (int i = 0; i < iters; ++i) { + if (BLOCK_N <= WARP_SIZE) { + row = start_row + i * warp_rows + row_off; + col = item * BLOCK_N + col_off; + } else { + row = start_row; + col = item * BLOCK_N + i * WARP_SIZE + col_off; + } + + if (row < M && col < N) { + val = static_cast(x[row * stride_xm + col * stride_xn]); + } else { + val = 0.0f; + } + + if (sym) { + local_max = fmaxf(local_max, fabsf(val)); + } else { + local_max = fmaxf(local_max, val); + local_min = fminf(local_min, val); + } + } + + if (sym) { + float block_max = warpReduceMax(local_max); + scale = fmaxf(block_max * SYM_CONS, 1e-5f); + } else { + 
float block_max = warpReduceMax(local_max); + float block_min = warpReduceMin(local_min); + scale = fmaxf((block_max - block_min) * ASYM_CONS, 1e-5f); + zero = fminf(fmaxf(-rintf(block_min / scale), 0.0f), 15.0f); + } + + for (int i = 0; i < iters; ++i) { + if (BLOCK_N <= WARP_SIZE) { + row = start_row + i * warp_rows + row_off; + col = item * BLOCK_N + col_off; + } else { + row = start_row; + col = item * BLOCK_N + i * WARP_SIZE + col_off; + } + + if (row < M && col < N) { + float val = static_cast(x[row * stride_xm + col * stride_xn]); + if (sym) { + val = rintf(val / scale); + } else { + val = rintf(val / scale) + zero; + } + + out[row * stride_om + col * stride_on] = static_cast(val); + out_scale[blockIdx.x * stride_osm + item * stride_osn] = static_cast(scale); + if (!sym) { + out_zero[blockIdx.x * stride_ozm + item * stride_ozn] = static_cast(zero); + } + } + } + } +} + +// dispatch +template +void launch_int4_quant_kernel( + const scalar_t* x, + scalar_t* out, + scalar_t* out_scale, + scalar_t* out_zero, + int M, int N, + const int stride_xm, const int stride_xn, + const int stride_om, const int stride_on, + const int stride_osm, const int stride_osn, + const int stride_ozm, const int stride_ozn, + int block_m, int block_n, + bool sym, + cudaStream_t stream +) { + constexpr int WARPS_PER_BLOCK = 8; + constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * 32; // 256 + + if (block_m == 1 && block_n == 32) { + dim3 grid(M); + dim3 block(THREADS_PER_BLOCK); + + int4_quant_1x32_kernel<<>>( + x, out, out_scale, out_zero, M, N, + stride_xm, stride_xn, + stride_om, stride_on, + stride_osm, stride_osn, + stride_ozm, stride_ozn, + sym + ); + } else if (block_m == 32 && block_n == 1) { + dim3 grid(ceil_div(M, block_m)); + dim3 block(THREADS_PER_BLOCK); + + int4_quant_32x1_kernel<<>>( + x, out, out_scale, out_zero, M, N, + stride_xm, stride_xn, + stride_om, stride_on, + stride_osm, stride_osn, + stride_ozm, stride_ozn, + sym + ); + } else { + dim3 grid(ceil_div(M, block_m)); + dim3 block(THREADS_PER_BLOCK); + int4_quant_common_kernel<<>>( + x, out, out_scale, out_zero, M, N, + stride_xm, stride_xn, + stride_om, stride_on, + stride_osm, stride_osn, + stride_ozm, stride_ozn, + block_m, block_n, + sym + ); + } +} + +std::tuple +fake_int4_quant_cuda( + torch::Tensor& x, + std::vector& block_size, + bool sym +) { + TORCH_CHECK(x.dim() == 2, "Input must be 2D"); + TORCH_CHECK(x.is_cuda(), "Input must be on CUDA"); + + int M = x.size(0); + int N = x.size(1); + int block_m = block_size[0]; + int block_n = block_size[1]; + + TORCH_CHECK(block_m > 0 && block_n > 0, "Block sizes must be positive, got block_m=", block_m, ", block_n=", block_n); + TORCH_CHECK((block_m * block_n) % 32 == 0, + "block_m * block_n (", block_m * block_n, ") must be divisible by 32. 
" + "But got a ", block_m, "x", block_n, " block."); + + auto out = torch::empty_like(x); + auto out_scale = torch::empty({ceil_div(M, block_m), ceil_div(N, block_n)}, x.options()); + auto out_zero = torch::empty_like(out_scale); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, + x.scalar_type(), "int4_quant_cuda", [&] { + launch_int4_quant_kernel( + x.const_data_ptr(), + out.data_ptr(), + out_scale.data_ptr(), + out_zero.data_ptr(), + M, N, + x.stride(0), x.stride(1), + out.stride(0), out.stride(1), + out_scale.stride(0), out_scale.stride(1), + out_zero.stride(0), out_zero.stride(1), + block_m, block_n, + sym, + stream + ); + }); + + return std::make_tuple(out, out_scale, out_zero); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fake_int4_quant_cuda", &fake_int4_quant_cuda, "fake INT4 quantization cuda"); +} diff --git a/miles/backends/megatron_utils/kernels/int4_qat/setup.py b/miles/backends/megatron_utils/kernels/int4_qat/setup.py new file mode 100644 index 000000000..b27967bc9 --- /dev/null +++ b/miles/backends/megatron_utils/kernels/int4_qat/setup.py @@ -0,0 +1,39 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +# Get CUDA arch list +arch_list = [] +if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + arch_list.append(f"{major}.{minor}") + arch_list = sorted(set(arch_list)) + +setup( + name="fake_int4_quant_cuda", + ext_modules=[ + CUDAExtension( + name="fake_int4_quant_cuda", + sources=["fake_int4_quant_cuda.cu"], + extra_compile_args={ + "cxx": [ + "-O3", + "-std=c++17", + ], + "nvcc": [ + "-O3", + "-std=c++17", + "--expt-relaxed-constexpr", + "-Xcompiler", + "-fPIC", + ] + + [ + f'-gencode=arch=compute_{arch.replace(".", "")},code=sm_{arch.replace(".", "")}' + for arch in arch_list + ], + }, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/miles/backends/megatron_utils/megatron_to_hf/__init__.py b/miles/backends/megatron_utils/megatron_to_hf/__init__.py index ba5a286a3..b9b394cbc 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/__init__.py @@ -1,10 +1,10 @@ from .deepseekv3 import convert_deepseekv3_to_hf +from .deepseekv32 import convert_deepseekv32_to_hf from .glm4 import convert_glm4_to_hf from .glm4moe import convert_glm4moe_to_hf from .llama import convert_llama_to_hf from .mimo import convert_mimo_to_hf -from .processors.padding_remover import remove_padding -from .processors.quantizer import quantize_params +from .processors import quantize_params, remove_padding from .qwen2 import convert_qwen2_to_hf from .qwen3_next import convert_qwen3_next_to_hf from .qwen3moe import convert_qwen3moe_to_hf @@ -23,9 +23,6 @@ def convert_to_hf(args, model_name, name, param, quantization_config=None): converted_named_tensors = _convert_to_hf_core(args, model_name, name, param) - if not quantization_config: - return converted_named_tensors - return quantize_params(args, name, converted_named_tensors, quantization_config) @@ -45,6 +42,8 @@ def _convert_to_hf_core(args, model_name, name, param): converted_named_tensors = convert_qwen3_next_to_hf(args, name, param) elif "qwen2" in model_name or "qwen3" in model_name: converted_named_tensors = convert_qwen2_to_hf(args, name, param) + elif "deepseekv32" in model_name: + converted_named_tensors = 
convert_deepseekv32_to_hf(args, name, param) elif "deepseekv3" in model_name: converted_named_tensors = convert_deepseekv3_to_hf(args, name, param) diff --git a/miles/backends/megatron_utils/megatron_to_hf/deepseekv3.py b/miles/backends/megatron_utils/megatron_to_hf/deepseekv3.py index 1200b18d9..3f29f197e 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/deepseekv3.py +++ b/miles/backends/megatron_utils/megatron_to_hf/deepseekv3.py @@ -108,4 +108,22 @@ def convert_deepseekv3_to_hf(args, name, param): elif rest == "mlp.router.expert_bias": return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + mtp_layer_pattern = r"module\.module\.mtp\.layers\.(\d+)\.(.+)" + match = re.match(mtp_layer_pattern, name) + if match: + layer_idx, rest = match.groups() + layer_idx = int(layer_idx) + args.num_layers + if rest == "eh_proj.weight": + return [(f"model.layers.{layer_idx}.eh_proj.weight", param)] + elif rest == "enorm.weight": + return [(f"model.layers.{layer_idx}.enorm.weight", param)] + elif rest == "hnorm.weight": + return [(f"model.layers.{layer_idx}.hnorm.weight", param)] + elif rest == "final_layernorm.weight": + return [(f"model.layers.{layer_idx}.shared_head.norm.weight", param)] + else: + name = f"module.module.decoder.layers.{layer_idx}.{rest}" + name = name.replace("transformer_layer.", "") + return convert_deepseekv3_to_hf(args, name, param) + raise ValueError(f"Unknown parameter name: {name}") diff --git a/miles/backends/megatron_utils/megatron_to_hf/deepseekv32.py b/miles/backends/megatron_utils/megatron_to_hf/deepseekv32.py new file mode 100644 index 000000000..3b6519ada --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/deepseekv32.py @@ -0,0 +1,135 @@ +import re + +import sglang +import torch +from packaging.version import parse + + +def convert_deepseekv32_to_hf(args, name, param): + if name == "module.module.embedding.word_embeddings.weight": + return [("model.embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [("lm_head.weight", param)] + if name == "module.module.decoder.final_layernorm.weight": + return [("model.norm.weight", param)] + + try: + head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads + except AttributeError: + head_dim = args.hidden_size // args.num_attention_heads + value_num_per_group = args.num_attention_heads // args.num_query_groups + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest == "linear_fc1": + gate_weight, up_weight = param.chunk(2, dim=0) + outputs = [ + (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight), + ] + return outputs + elif rest == "linear_fc2": + outputs = [ + (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param), + ] + if parse(sglang.__version__) < parse("0.4.9.post5") and args.sglang_enable_ep_moe: + outputs += [ + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.input_scale", + torch.tensor(1.0, dtype=torch.float32, device=param.device), + ), + ( + f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale", + torch.tensor(1.0, 
dtype=torch.float32, device=param.device), + ), + ] + return outputs + else: + raise ValueError(f"Unknown expert parameter name: {name}") + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest == "linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight), + ] + elif rest == "linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)] + else: + raise ValueError(f"Unknown shared expert parameter name: {name}") + + if rest == "self_attention.linear_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_q_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_proj.weight", param)] + elif rest == "self_attention.linear_q_down_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)] + elif rest == "self_attention.q_layernorm.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)] + elif rest == "self_attention.linear_q_up_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)] + elif rest == "self_attention.linear_qkv.bias": + param = param.view(args.num_query_groups, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + return [ + (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), + (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), + (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), + ] + elif rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + elif rest == "mlp.linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] + elif rest == "self_attention.linear_qkv.layer_norm_weight" or rest == "input_layernorm.weight": + return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)] + elif rest == "mlp.linear_fc1.layer_norm_weight": + return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_down_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param)] + elif rest == "self_attention.kv_layernorm.weight": + return [(f"model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_up_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + # DSA Indexer parameters + elif rest == "self_attention.core_attention.indexer.linear_wq_b.weight": + return [(f"model.layers.{layer_idx}.self_attn.indexer.wq_b.weight", param)] + elif rest == "self_attention.core_attention.indexer.linear_wk.weight": + return [(f"model.layers.{layer_idx}.self_attn.indexer.wk.weight", param)] + elif rest == "self_attention.core_attention.indexer.k_norm.weight": + return 
[(f"model.layers.{layer_idx}.self_attn.indexer.k_norm.weight", param)] + elif rest == "self_attention.core_attention.indexer.k_norm.bias": + return [(f"model.layers.{layer_idx}.self_attn.indexer.k_norm.bias", param)] + elif rest == "self_attention.core_attention.indexer.linear_weights_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.indexer.weights_proj.weight", param)] + elif rest == "mlp.router.weight": + return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + elif rest == "mlp.router.expert_bias": + return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + + raise ValueError(f"Unknown parameter name: {name}") diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py b/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py index e69de29bb..0141c3548 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py @@ -0,0 +1,15 @@ +from .padding_remover import remove_padding +from .quantizer_compressed_tensors import quantize_params_compressed_tensors +from .quantizer_fp8 import quantize_params_fp8 + +__all__ = ["remove_padding", "quantize_param", "quantize_params_fp8", "quantize_params_compressed_tensors"] + + +def quantize_params(args, megatron_name, converted_named_params, quantization_config): + if quantization_config is None: + return converted_named_params + elif quantization_config["quant_method"] == "fp8": + return quantize_params_fp8(args, megatron_name, converted_named_params, quantization_config) + elif quantization_config["quant_method"] == "compressed-tensors": + # only int4 at the moment. + return quantize_params_compressed_tensors(converted_named_params, quantization_config) diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py new file mode 100644 index 000000000..41712b33a --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py @@ -0,0 +1,189 @@ +import logging +import math +import re +from typing import Literal + +import torch + +logger = logging.getLogger(__name__) + + +__all__ = ["quantize_params_compressed_tensors"] + + +def pack_to_int32( + value: torch.Tensor, + num_bits: int, + packed_dim: Literal[0] | Literal[1] = 1, +) -> torch.Tensor: + """ + Packs a tensor of quantized weights stored in int8 into int32s with padding + + Pseudocode: + 1. Shift wrt num_bits to convert to unsigned. num_bits=8 + [1,2] -> [129, 130] + 2. Pad to fill in 32 bits + [129, 130] -> [129, 130, 0, 0] + 3. convert to binary align in order + [129, 130, 0, 0] -> 00000000 00000000 10000010 10000001 + 4. convert aligned binary to number + 00000000000000001000001010000001 -> 33409 + 5. 
covert back to uint32 + 33409 -> 33409 + + :param value: tensor to pack + :param num_bits: number of bits used to store underlying data, must be at least 1 + :returns: packed int32 tensor + """ + if value.dtype is not torch.int8: + raise ValueError("Tensor must be quantized to torch.int8 before packing") + + if num_bits > 8: + raise ValueError("Packing is only supported for less than 8 bits") + + if num_bits < 1: + raise ValueError(f"num_bits must be at least 1, got {num_bits}") + + # Convert to unsigned range for packing, matching quantization offset + offset = 1 << (num_bits - 1) + value = (value + offset).to(torch.uint8) + device = value.device + + pack_factor = 32 // num_bits + + if packed_dim == 0: + value = value.transpose(0, 1) + + rows, cols = value.shape + padded_cols = math.ceil(cols / pack_factor) * pack_factor + pad_len = padded_cols - cols + + if pad_len > 0: + value = torch.nn.functional.pad(value, (0, pad_len)) + + num_groups = padded_cols // pack_factor + + # Use int32 here + reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32) + bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits + packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32) + + if packed_dim == 0: + packed = packed.transpose(0, 1) + + return packed + + +def pack_int4_to_int32(q_weight: torch.Tensor) -> torch.Tensor: + """ + pack int4 to int32 + Args: + q_weight: [N, K] tensor, dtype=int8 or uint8 + Returns: + packed: [N, K // 8] tensor, dtype=int32 + """ + return pack_to_int32(q_weight, 4, -1) + + +def int4_block_quantize(x: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + De-quantized = Scale * Quantized (Zero Point is always 0) + """ + N, K = x.shape + if group_size == -1: + group_size = K + + # Padding + if K % group_size != 0: + import torch.nn.functional as F + + x = F.pad(x, (0, group_size - (K % group_size))) + N, K = x.shape + + num_groups = K // group_size + x_reshaped = x.float().view(N, num_groups, group_size) + + # ========================================================= + # 1. Scale + # Range: [-7, 7] -> dividing by 7.0 + # ========================================================= + x_abs_max = x_reshaped.abs().amax(dim=-1, keepdim=True) + scale = x_abs_max / 7.0 + scale = scale.clamp(min=1e-5) + + # ========================================================= + # 2. Quantize + # ========================================================= + x_int_sym = (x_reshaped / scale).round().clamp(-8, 7) + + out = x_int_sym.to(torch.int8) + + # ========================================================= + # 3. 
Zero Point + # ========================================================= + zero_point = torch.zeros_like(scale) + out = out.view(N, K) + + scale_out = scale.squeeze(-1).contiguous() + zero_out = zero_point.squeeze(-1).contiguous() + + return out, scale_out, zero_out + + +def quantize_params_compressed_tensors(converted_named_params, quantization_config): + w_cfg = quantization_config["config_groups"]["group_0"]["weights"] + group_size = w_cfg["group_size"] + is_symmetric = w_cfg["symmetric"] + ignore_rules = quantization_config.get("ignore", []) + + results = [] + + for name, param in converted_named_params: + is_ignored = any((r.startswith("re:") and re.match(r[3:], name)) or r == name for r in ignore_rules) + + if is_ignored or not name.endswith(".weight") or param.dim() < 2: + results.append((name, param)) + continue + + input_tensor = param.view(-1, param.shape[-1]) if param.dim() > 2 else param + + if group_size != -1 and input_tensor.shape[-1] < group_size: + logger.warning(f"Skipping {name}, K-dim {input_tensor.shape[-1]} < group_size") + results.append((name, param)) + continue + + results.extend(_quantize_param_int4(name, input_tensor, group_size, param.shape, is_symmetric)) # origin shape + + return results + + +def _quantize_param_int4(name: str, weight: torch.Tensor, group_size: int, shape: torch.Tensor, is_symmetric: bool): + """ + Wraps the quantization function, handles renaming and packing. + """ + base_name = name.replace(".weight", "") + + new_base_name = base_name + + original_dtype = weight.dtype + + if group_size == -1: + group_size = weight.shape[1] + elif weight.shape[1] % group_size != 0: + logger.warning( + f"Weight {name} with shape {weight.shape} has K-dimension " + f"not divisible by group_size {group_size}. Skipping." + ) + return [(name, weight.to(original_dtype))] + + q_weight, scales, zeros = int4_block_quantize(weight, group_size) + + packed_q_weight = pack_int4_to_int32(q_weight) + + qweight_name = f"{new_base_name}.weight_packed" + scales_name = f"{new_base_name}.weight_scale" + qweight_shape = f"{new_base_name}.weight_shape" + + q_shape = torch.tensor(shape, dtype=torch.int32, device="cuda") + + return [(qweight_name, packed_q_weight), (scales_name, scales.to(original_dtype)), (qweight_shape, q_shape)] diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer.py b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_fp8.py similarity index 82% rename from miles/backends/megatron_utils/megatron_to_hf/processors/quantizer.py rename to miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_fp8.py index 9876f6020..da5b2b55a 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer.py +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_fp8.py @@ -7,9 +7,7 @@ from ...sglang import quant_weight_ue8m0, should_deepgemm_weight_requant_ue8m0, transform_scale_ue8m0 -def quantize_params(args, megatron_name, converted_named_params, quantization_config): - if quantization_config is None: - return converted_named_params +def quantize_params_fp8(args, megatron_name, converted_named_params, quantization_config): assert quantization_config["quant_method"] == "fp8" assert quantization_config["fmt"] == "e4m3" assert quantization_config["activation_scheme"] == "dynamic" @@ -44,7 +42,10 @@ def quantize_params(args, megatron_name, converted_named_params, quantization_co # TODO: find a clearer way. 
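As a concrete reference for the int4 path above, a small usage sketch (not part of the patch): it round-trips a random weight through int4_block_quantize and pack_int4_to_int32 and notes the resulting shapes, assuming the new processors module from this diff is importable.

import torch

from miles.backends.megatron_utils.megatron_to_hf.processors.quantizer_compressed_tensors import (
    int4_block_quantize,
    pack_int4_to_int32,
)

weight = torch.randn(16, 256)  # [N, K] with K divisible by group_size
q, scale, zero = int4_block_quantize(weight, group_size=128)

# q:     [16, 256] int8, clamped to the signed 4-bit range [-8, 7]
# scale: [16, 2]   one scale per 128-wide group (group absmax / 7)
# zero:  [16, 2]   always zero because the scheme is symmetric
packed = pack_int4_to_int32(q)  # [16, 32] int32, eight 4-bit values per word

# Symmetric dequantization: dequant = scale * q (zero point is 0)
dequant = (q.float().view(16, 2, 128) * scale.unsqueeze(-1)).view(16, 256)
max_err = (dequant - weight).abs().max()  # bounded by roughly 0.5 * scale per group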
if converted_name.endswith("_scale"): continue - quantize_named_params.extend(_quantize_param(converted_name, param, weight_block_size)) + if_use_ue8m0_in_moe = True if args.sglang_moe_runner_backend == "deep_gemm" else False + quantize_named_params.extend( + _quantize_param(converted_name, param, weight_block_size, if_use_ue8m0_in_moe=if_use_ue8m0_in_moe) + ) return quantize_named_params @@ -74,6 +75,9 @@ def quantize_params(args, megatron_name, converted_named_params, quantization_co "self_attention.linear_q_up_proj.weight", "self_attention.linear_kv_down_proj.weight", "self_attention.linear_kv_up_proj.weight", + # dsa indexer + "self_attention.core_attention.indexer.linear_wq_b.weight", + "self_attention.core_attention.indexer.linear_wk.weight", ]: quantize_named_params = [] for converted_name, param in converted_named_params: @@ -85,13 +89,15 @@ def quantize_params(args, megatron_name, converted_named_params, quantization_co return converted_named_params -def _quantize_param(name, weight, weight_block_size): +def _quantize_param(name, weight, weight_block_size, if_use_ue8m0_in_moe=True): assert name.endswith(".weight"), f"Expected weight parameter, got {name}" FP8_MIN = torch.finfo(torch.float8_e4m3fn).min FP8_MAX = torch.finfo(torch.float8_e4m3fn).max if weight_block_size is not None: - if should_deepgemm_weight_requant_ue8m0 and should_deepgemm_weight_requant_ue8m0( - weight_block_size=weight_block_size + if ( + should_deepgemm_weight_requant_ue8m0 + and should_deepgemm_weight_requant_ue8m0(weight_block_size=weight_block_size) + and if_use_ue8m0_in_moe ): qweight, scale = quant_weight_ue8m0(weight, weight_block_size=weight_block_size) scale = transform_scale_ue8m0(scale, mn=qweight.shape[-2]) diff --git a/miles/backends/megatron_utils/megatron_to_hf/qwen3_next.py b/miles/backends/megatron_utils/megatron_to_hf/qwen3_next.py index 91e9ba88a..6cb7a2169 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/qwen3_next.py +++ b/miles/backends/megatron_utils/megatron_to_hf/qwen3_next.py @@ -62,8 +62,7 @@ def convert_qwen3_next_to_hf(args, name, param): if rest == "self_attention.linear_proj.weight": return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] - elif rest == "self_attention.linear_qgkv.weight": - + elif rest == "self_attention.linear_qkv.weight": param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size) q_param, k_param, v_param = torch.split( param, split_size_or_sections=[2 * value_num_per_group, 1, 1], dim=1 @@ -80,7 +79,7 @@ def convert_qwen3_next_to_hf(args, name, param): (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param), (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param), ] - elif rest == "self_attention.linear_qgkv.bias": + elif rest == "self_attention.linear_qkv.bias": param = param.view(args.num_query_groups, -1) q_bias, k_bias, v_bias = torch.split( param, @@ -103,7 +102,7 @@ def convert_qwen3_next_to_hf(args, name, param): ] elif rest == "mlp.linear_fc2.weight": return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] - elif rest == "self_attention.linear_qgkv.layer_norm_weight": + elif rest == "self_attention.linear_qkv.layer_norm_weight": return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)] elif rest == "mlp.linear_fc1.layer_norm_weight": return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index c4e182797..fea58feec 100644 --- a/miles/backends/megatron_utils/model.py 
+++ b/miles/backends/megatron_utils/model.py @@ -6,6 +6,7 @@ from argparse import Namespace from collections.abc import Callable, Sequence from functools import partial +from pathlib import Path import torch from megatron.core import mpu @@ -21,14 +22,16 @@ from megatron.training.global_vars import get_args from megatron.training.training import get_model -from miles.utils import tracking_utils from miles.utils.memory_utils import clear_memory +from ..training_utils.ci_utils import check_grad_norm, check_kl +from ..training_utils.data import DataIterator, get_batch +from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step +from ..training_utils.loss import loss_function +from ..training_utils.parallel import ParallelState from .checkpoint import load_checkpoint, save_checkpoint -from .cp_utils import slice_with_cp -from .data import DataIterator, get_batch -from .loss import loss_function from .model_provider import get_model_provider_func +from .parallel import get_packed_seq_params logger = logging.getLogger(__name__) @@ -84,9 +87,6 @@ def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer) def setup_model_and_optimizer( args: Namespace, role: str = "actor", - no_wd_decay_cond: Callable[..., bool] | None = None, - scale_lr_cond: Callable[..., bool] | None = None, - lr_mult: float = 1.0, ) -> tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler]: """Build model(s), wrap with DDP, and construct optimizer and scheduler. @@ -119,11 +119,8 @@ def setup_model_and_optimizer( config.timers = None optimizer = get_megatron_optimizer( - config, - model, - no_wd_decay_cond, - scale_lr_cond, - lr_mult, + config=config, + model_chunks=model, use_gloo_process_groups=args.enable_gloo_process_groups, ) opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer) @@ -160,6 +157,7 @@ def forward_only( model: Sequence[DDP], data_iterator: Sequence[DataIterator], num_microbatches: Sequence[int], + parallel_state: ParallelState, store_prefix: str = "", ) -> dict[str, list[torch.Tensor]]: """Run forward passes only and collect non-loss outputs (e.g., logprobs). @@ -211,13 +209,21 @@ def forward_step( # Get the batch. 
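For orientation, a minimal sketch of what the new get_packed_seq_params helper (added in miles/backends/megatron_utils/parallel.py further down in this diff) returns for the packed "thd" layout. The batch dict below is hand-built with the cu_seqlens/max_seqlen fields that get_batch normally provides and is illustrative only.

import torch
from argparse import Namespace

from miles.backends.megatron_utils.parallel import get_packed_seq_params

# Three packed sequences of lengths 5, 3 and 4 -> cumulative boundaries [0, 5, 8, 12].
batch = {
    "cu_seqlens": torch.tensor([0, 5, 8, 12], dtype=torch.int),
    "max_seqlen": 5,
}
params = get_packed_seq_params(batch, Namespace(qkv_format="thd"))
assert params.qkv_format == "thd" and params.max_seqlen_q == 5
assert "packed_seq_params" in batch  # the helper also caches the result on the batch
# For any other qkv_format (e.g. "bshd") the helper returns None and attention runs unpacked.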
batch = get_batch( data_iterator, - ["tokens", "total_lengths", "response_lengths", "max_seq_lens"], + [ + "tokens", + "loss_masks", + "multimodal_train_inputs", + "total_lengths", + "response_lengths", + "max_seq_lens", + ], + parallel_state, args.data_pad_size_multiplier, args.qkv_format, ) unconcat_tokens = batch["unconcat_tokens"] tokens = batch["tokens"] - packed_seq_params = batch["packed_seq_params"] + packed_seq_params = get_packed_seq_params(batch, args) total_lengths = batch["total_lengths"] response_lengths = batch["response_lengths"] output_tensor = model( @@ -226,11 +232,14 @@ def forward_step( attention_mask=None, labels=None, packed_seq_params=packed_seq_params, + loss_mask=batch["full_loss_masks"], + **(batch["multimodal_train_inputs"] if batch["multimodal_train_inputs"] is not None else {}), ) return output_tensor, partial( f, args=args, + parallel_state=parallel_state, unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, @@ -273,22 +282,9 @@ def forward_step( rollout_data = {} # Store the results on the last stage if mpu.is_pipeline_last_stage(): - keys = forward_data_store[0].keys() - for key in keys: - values = [] - for value in forward_data_store: - assert isinstance(value[key], list) - values += value[key] - - if args.use_dynamic_batch_size: - # TODO: This is ugly... Find a better way to make the data have the same order. - # TODO: move this out of the loop. - origin_values = [None] * len(values) - origin_indices = sum(data_iterator[0].micro_batch_indices, []) - for value, origin_index in zip(values, origin_indices, strict=False): - origin_values[origin_index] = value - values = origin_values - rollout_data[f"{store_prefix}{key}"] = values + aggregated = aggregate_forward_results(forward_data_store, data_iterator[0], args, store_prefix="") + for key, value in aggregated.items(): + rollout_data[f"{store_prefix}{key}"] = value return rollout_data @@ -301,6 +297,7 @@ def train_one_step( optimizer: MegatronOptimizer, opt_param_scheduler: OptimizerParamScheduler, num_microbatches: int, + parallel_state: ParallelState, ) -> tuple[dict[str, float], float]: """Execute a single pipeline-parallel training step. @@ -355,6 +352,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p data_iterator, [ "tokens", + "multimodal_train_inputs", "packed_seq_params", "total_lengths", "response_lengths", @@ -367,6 +365,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "max_seq_lens", ], + parallel_state, args.data_pad_size_multiplier, args.qkv_format, ) @@ -375,36 +374,6 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p old_stage = os.environ["ROUTING_REPLAY_STAGE"] os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward" - def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None: - tokens_tensor: torch.Tensor = batch["tokens"] - - mask_chunks: list[torch.Tensor] = [] - for total_len, response_len, resp_mask in zip( - batch["total_lengths"], batch["response_lengths"], batch["loss_masks"], strict=False - ): - assert ( - resp_mask.numel() == response_len - ), f"Unexpected loss mask size {resp_mask.numel()} (expected {response_len} or {total_len})." 
- prompt_len = total_len - response_len - full_mask = resp_mask.new_zeros(total_len) - full_mask[prompt_len:] = resp_mask - - mask_chunks.append(slice_with_cp(full_mask, 0.0)) - - flattened_mask = torch.cat(mask_chunks, dim=0) - seq_len = tokens_tensor.size(-1) - assert ( - flattened_mask.numel() <= seq_len - ), f"MTP loss mask ({flattened_mask.numel()}) exceeds token length ({seq_len})." - - # token tensor may be padded by 128, so pad loss mask to the same length - loss_mask_tensor = flattened_mask.new_zeros(seq_len) - loss_mask_tensor[: flattened_mask.numel()] = flattened_mask - return loss_mask_tensor.unsqueeze(0) - - loss_mask = None - mtp_kwargs = None - if return_schedule_plan: assert not args.enable_mtp_training, "MTP training should not be enabled when using combined 1f1b" output_tensor = model.build_schedule_plan( @@ -412,35 +381,33 @@ def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None: position_ids=None, attention_mask=None, labels=None, - packed_seq_params=batch["packed_seq_params"], + packed_seq_params=get_packed_seq_params(batch, args), + loss_mask=batch["full_loss_masks"], ) else: - # If enabling MTP training: trigger MTP loss inside Megatron while returning logits - # for the target model's loss. + forward_kwargs = { + "input_ids": batch["tokens"], + "position_ids": None, + "attention_mask": None, + "labels": None, + "packed_seq_params": get_packed_seq_params(batch, args), + "loss_mask": batch["full_loss_masks"], + } + if args.enable_mtp_training: - loss_mask = build_loss_mask_for_mtp(batch) - assert ( - loss_mask.shape == batch["tokens"].shape - ), f"loss_mask shape {loss_mask.shape} mismatches token shape {batch['tokens'].shape}" - mtp_kwargs = { - # We have to set labels to tokens for MTP training, to point out samples to train. - "mtp_labels": batch["tokens"], - } - - output_tensor = model( - input_ids=batch["tokens"], - position_ids=None, - attention_mask=None, - labels=None, - packed_seq_params=batch["packed_seq_params"], - loss_mask=loss_mask, - **(dict(mtp_kwargs=mtp_kwargs) if mtp_kwargs is not None else {}), - ) + forward_kwargs["mtp_kwargs"] = {"mtp_labels": batch["tokens"]} + + if batch["multimodal_train_inputs"] is not None: + forward_kwargs.update(batch["multimodal_train_inputs"]) + + output_tensor = model(**forward_kwargs) if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1": os.environ["ROUTING_REPLAY_STAGE"] = old_stage - return output_tensor, partial(loss_function, args, batch, num_microbatches) + return output_tensor, partial( + loss_function, args, parallel_state, batch, num_microbatches, apply_megatron_loss_scaling=True + ) # Forward pass. forward_backward_func = get_forward_backward_func() @@ -467,6 +434,13 @@ def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None: else: valid_step = not (math.isnan(grad_norm) or math.isinf(grad_norm)) + # CI check: verify only MTP parameters have non-zero gradients when truncation happens + # This check must happen before optimizer.step() as gradients may be modified during step + if args.ci_test and args.enable_mtp_training: + from miles.backends.megatron_utils.ci_utils import check_mtp_only_grad + + check_mtp_only_grad(model, step_id) + if valid_step: # Update parameters. update_successful, grad_norm, num_zeros_in_grad = optimizer.step() @@ -481,22 +455,7 @@ def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None: optimizer.zero_grad() if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Average loss across microbatches. 
- keys = losses_reduced[0]["keys"] - values = None - for x in losses_reduced: - if values is None: - values = x["values"] - else: - values += x["values"] - assert len(keys) + 1 == values.numel() - torch.distributed.all_reduce(values, group=mpu.get_data_parallel_group(with_context_parallel=True)) - - loss_reduced = {} - values = values.tolist() - num_samples_or_tokens = values[0] - for key, value in zip(keys, values[1:], strict=False): - loss_reduced[key] = value * mpu.get_context_parallel_world_size() / num_samples_or_tokens + loss_reduced = aggregate_train_losses(losses_reduced, parallel_state) return loss_reduced, grad_norm return {}, grad_norm @@ -506,6 +465,16 @@ def should_disable_forward_pre_hook(args: Namespace) -> bool: return args.use_distributed_optimizer and args.overlap_param_gather +def finalize_model_grads_with_empty_cache(*args, **kwargs): + # trigger empty cache when there are less than 10% free memory before the final reduce scatter. + # TODO: this is an ad-hoc method and we should figure out why the oom happens in the first place. + device = torch.cuda.current_device() + free, total = torch.cuda.mem_get_info(device) + if free / total < 0.1: + clear_memory() + return finalize_model_grads(*args, **kwargs) + + def train( rollout_id: int, model: Sequence[DDP], @@ -513,6 +482,7 @@ def train( opt_param_scheduler: OptimizerParamScheduler, data_iterator: Sequence[DataIterator], num_microbatches: Sequence[int], + parallel_state: ParallelState, ) -> None: """Run training over a rollout consisting of multiple steps. @@ -556,10 +526,27 @@ def train( config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: config.param_sync_func = config.param_sync_func[0] - config.finalize_model_grads_func = finalize_model_grads + config.finalize_model_grads_func = finalize_model_grads_with_empty_cache pre_hook_enabled = False + if args.reset_optimizer_states: + if ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 + and mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1 + ): + print("Reset optimizer states") + for chained_optimizer in optimizer.chained_optimizers: + for group in chained_optimizer.optimizer.param_groups: + if "step" in group: + group["step"] = 0 + for state in chained_optimizer.optimizer.state.values(): + if "exp_avg" in state: + state["exp_avg"].zero_() + if "exp_avg_sq" in state: + state["exp_avg_sq"].zero_() + if args.manual_gc: # Disable the default garbage collector and perform the collection manually. # This is to align the timing of garbage collection across ranks. @@ -593,6 +580,7 @@ def train( optimizer, opt_param_scheduler, num_microbatches[step_id], + parallel_state, ) if step_id == 0: @@ -619,6 +607,12 @@ def train( mtp_losses = (tracker["values"] * mtp_loss_scale).item() MTPLossLoggingHelper.clean_loss_in_tracker() + # CI check: verify MTP loss is within expected bounds + if args.ci_test: + from miles.backends.megatron_utils.ci_utils import check_mtp_loss + + check_mtp_loss(mtp_losses) + # per train step log. 
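The reset_optimizer_states block above zeroes Adam's moments and step counters on Megatron's chained optimizer; the sketch below shows the same idea in isolation on a plain torch.optim.Adam (an illustrative stand-in, not the Megatron code path).

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populates exp_avg / exp_avg_sq / step

# Zero the moment estimates and step counters so statistics from a previous
# phase do not bias the next one (mirrors the chained_optimizer loop above).
for state in optimizer.state.values():
    for key in ("exp_avg", "exp_avg_sq"):
        if key in state:
            state[key].zero_()
    if torch.is_tensor(state.get("step")):
        state["step"].zero_()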
if ( mpu.get_data_parallel_rank(with_context_parallel=True) == 0 @@ -628,52 +622,41 @@ def train( accumulated_step_id = rollout_id * num_steps_per_rollout + step_id role = getattr(model[0], "role", "actor") role_tag = "" if role == "actor" else f"{role}-" - log_dict = { - f"train/{role_tag}{key}": val.mean().item() if isinstance(val, torch.Tensor) else val - for key, val in loss_dict.items() - } - log_dict[f"train/{role_tag}grad_norm"] = grad_norm + + extra_metrics = {} if args.enable_mtp_training: - log_dict[f"train/{role_tag}mtp_loss"] = mtp_losses + extra_metrics["mtp_loss"] = mtp_losses for param_group_id, param_group in enumerate(optimizer.param_groups): - log_dict[f"train/{role_tag}lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group) - - log_dict["train/step"] = accumulated_step_id - tracking_utils.log(args, log_dict, step_key="train/step") + extra_metrics[f"lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group) + + log_dict = log_train_step( + args=args, + loss_dict=loss_dict, + grad_norm=grad_norm, + rollout_id=rollout_id, + step_id=step_id, + num_steps_per_rollout=num_steps_per_rollout, + role=role, + extra_metrics=extra_metrics, + should_log=True, + ) if args.ci_test and not args.ci_disable_kl_checker: - if step_id == 0 and "train/ppo_kl" in log_dict and "train/pg_clipfrac" in log_dict: - if args.multi_latent_attention: - # TODO: mla currently have non-zero kl, need further investigation - assert log_dict["train/ppo_kl"] < 1e-8, f"{log_dict=}" - else: - assert log_dict["train/ppo_kl"] == 0.0 and log_dict["train/pg_clipfrac"] == 0.0, f"{log_dict=}" - if accumulated_step_id == 0 and "train/kl_loss" in log_dict: - assert log_dict["train/kl_loss"] == 0.0, f"{log_dict=}" + check_kl(args, log_dict, step_id, accumulated_step_id) logger.info(f"{role_tag}step {accumulated_step_id}: {log_dict}") - if args.ci_save_grad_norm is not None: - ci_save_grad_norm_path = args.ci_save_grad_norm.format( - role=role, + if args.ci_test: + check_grad_norm( + args=args, + grad_norm=grad_norm, rollout_id=rollout_id, step_id=step_id, - ) - torch.save(grad_norm, ci_save_grad_norm_path) - elif args.ci_load_grad_norm is not None: - ci_load_grad_norm_path = args.ci_load_grad_norm.format( role=role, - rollout_id=rollout_id, - step_id=step_id, + rank=mpu.get_data_parallel_rank(), ) - expected_grad_norm = torch.load(ci_load_grad_norm_path) - assert math.isclose( - grad_norm, - expected_grad_norm, - rel_tol=0.01, - abs_tol=0.01, - ), f"grad norm mismatch: {grad_norm} != {expected_grad_norm}" + # Close out pre-hooks if using distributed optimizer and overlapped param gather. if pre_hook_enabled: disable_forward_pre_hook(model) @@ -707,6 +690,44 @@ def save( enable_forward_pre_hook(model) +def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: + """Save Megatron model in HuggingFace format. + + Args: + model (Sequence[DDP]): Sequence of DDP-wrapped model chunks. + rollout_id (int): Rollout ID for path formatting. 
+ """ + should_log = ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 + ) + + try: + from megatron.bridge import AutoBridge + + from miles.utils.megatron_bridge_utils import patch_megatron_model + + path = Path(args.save_hf.format(rollout_id=rollout_id)) + + if should_log: + logger.info(f"Saving model in HuggingFace format to {path}") + + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + + path.mkdir(parents=True, exist_ok=True) + + with patch_megatron_model(model): + bridge.save_hf_pretrained( + model, + path=path, + ) + + if should_log: + logger.info(f"Successfully saved HuggingFace model to {path}") + except Exception as e: + if should_log: + logger.error(f"Failed to save HuggingFace format: {e}") + + def initialize_model_and_optimizer( args: Namespace, role: str = "actor" ) -> tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler, int]: @@ -723,6 +744,7 @@ def initialize_model_and_optimizer( if torch.version.hip: import megatron.core.dist_checkpointing.strategies.filesystem_async as filesystem_async_module + from miles.utils.rocm_checkpoint_writer import ROCmFileSystemWriterAsync filesystem_async_module.FileSystemWriterAsync = ROCmFileSystemWriterAsync @@ -740,4 +762,6 @@ def initialize_model_and_optimizer( ) clear_memory() + opt_param_scheduler.step(increment=iteration * args.global_batch_size) + return model, optimizer, opt_param_scheduler, iteration diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 5b7b3dd74..7834f1101 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -16,6 +16,8 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training.arguments import core_transformer_config_from_args +from miles.utils.misc import load_function + # Adapt from https://github.com/volcengine/verl/blob/c3b20575d2bc815fcccd84bddb4c0401fc4b632b/verl/models/llama/megatron/layers/parallel_linear.py#L82 class LinearForLastLayer(torch.nn.Linear): @@ -53,6 +55,42 @@ def get_model_provider_func( args: argparse.Namespace, role: Literal["actor", "critic"] = "actor", ): + # Support custom model provider path (similar to --custom-rm-path for reward models) + if getattr(args, "custom_model_provider_path", None): + + def wrapped_model_provider( + pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None + ) -> GPTModel: + custom_model_provider = load_function(args.custom_model_provider_path) + # Check if the custom provider supports vp_stage parameter + has_vp_stage = "vp_stage" in inspect.signature(custom_model_provider).parameters + if has_vp_stage: + model = custom_model_provider(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + else: + model = custom_model_provider(pre_process=pre_process, post_process=post_process) + # Apply critic output layer if needed + if post_process and role == "critic": + model.output_layer = LinearForLastLayer( + input_size=model.config.hidden_size, output_size=1, config=model.config + ) + return model + + return wrapped_model_provider + + if args.megatron_to_hf_mode == "bridge": + from megatron.bridge import AutoBridge + + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + provider = bridge.to_megatron_provider(load_weights=False) + # TODO: we should not manually set this... 
+ provider.tensor_model_parallel_size = args.tensor_model_parallel_size + provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size + provider.expert_model_parallel_size = args.expert_model_parallel_size + provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size + provider.sequence_parallel = args.sequence_parallel + provider.finalize() + return provider.provide + def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: """Builds the model. diff --git a/miles/backends/megatron_utils/parallel.py b/miles/backends/megatron_utils/parallel.py new file mode 100644 index 000000000..e3d99fc46 --- /dev/null +++ b/miles/backends/megatron_utils/parallel.py @@ -0,0 +1,67 @@ +import logging +from argparse import Namespace +from collections.abc import Sequence + +import torch +from megatron.core import mpu +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import get_model_config + +from ..training_utils.parallel import ParallelState + +logger = logging.getLogger(__name__) + + +def create_megatron_parallel_state( + model: torch.nn.Module | Sequence[torch.nn.Module] | None = None, +) -> ParallelState: + vpp_size_value = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size_value is None: + vpp_size = 1 + microbatch_group_size_per_vp_stage = None + elif vpp_size_value > 1: + assert model is not None + model_to_check = model[0] if isinstance(model, Sequence) else model + config = get_model_config(model_to_check) + vpp_size = vpp_size_value + microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage + else: + vpp_size = 1 + microbatch_group_size_per_vp_stage = None + + parallel_state = ParallelState( + dp_rank=mpu.get_data_parallel_rank(with_context_parallel=False), + dp_src_rank=mpu.get_data_parallel_src_rank(with_context_parallel=True), + dp_size=mpu.get_data_parallel_world_size(with_context_parallel=False), + cp_rank=mpu.get_context_parallel_rank(), + cp_size=mpu.get_context_parallel_world_size(), + dp_cp_rank=mpu.get_data_parallel_rank(with_context_parallel=True), + dp_cp_size=mpu.get_data_parallel_world_size(with_context_parallel=True), + dp_group=mpu.get_data_parallel_group(with_context_parallel=False), + dp_cp_group=mpu.get_data_parallel_group(with_context_parallel=True), + dp_cp_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True), + cp_group=mpu.get_context_parallel_group(), + tp_size=mpu.get_tensor_model_parallel_world_size(), + tp_rank=mpu.get_tensor_model_parallel_rank(), + tp_group=mpu.get_tensor_model_parallel_group(), + is_pp_last_stage=mpu.is_pipeline_last_stage(), + vpp_size=vpp_size, + microbatch_group_size_per_vp_stage=microbatch_group_size_per_vp_stage, + ) + + return parallel_state + + +def get_packed_seq_params(batch: dict[str, torch.Tensor], args: Namespace) -> PackedSeqParams: + if args.qkv_format == "thd": + packed_seq_params = PackedSeqParams( + cu_seqlens_q=batch["cu_seqlens"], + cu_seqlens_kv=batch["cu_seqlens"], + max_seqlen_q=batch["max_seqlen"], + max_seqlen_kv=batch["max_seqlen"], + qkv_format="thd", + ) + batch["packed_seq_params"] = packed_seq_params + return packed_seq_params + else: + return None diff --git a/miles/backends/megatron_utils/update_weight/common.py b/miles/backends/megatron_utils/update_weight/common.py index 85fe76a1b..e958566dc 100644 --- a/miles/backends/megatron_utils/update_weight/common.py +++ b/miles/backends/megatron_utils/update_weight/common.py @@ -203,6 +203,17 @@ def 
_named_params_and_buffers_global( yield f"module.module.mtp.layers.{layer_idx}.transformer_layer.mlp.experts.{rest}.weight{expert_idx}", param continue + # TODO: a hacking here, need to be cleaner + duplicated = [ + "indexer.linear_weights_proj", + "indexer.linear_wk", + "indexer.linear_wq_b", + "linear_q_down_proj", + "linear_kv_down_proj", + ] + if any(dup in name for dup in duplicated): + param.parallel_mode = "duplicated" + layer_idx, rest = match.groups() layer_idx = int(layer_idx) + layer_offset diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index a88d18b36..7e0a4817e 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -13,9 +13,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) from megatron.bridge import AutoBridge + import miles_plugins.megatron_bridge # noqa: F401 - self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint) + self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint, trust_remote_code=True) def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 801074553..f9c90bb1b 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -80,6 +80,13 @@ def update_weights(self) -> None: if dist.get_rank() == 0: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + # int4/fp4 pre_process + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=True, + post_process_quantization=False, + rollout_engines=self.rollout_engines, + ) dist.barrier(group=get_gloo_group()) buffer_size = 0 @@ -111,9 +118,15 @@ def update_weights(self) -> None: if named_tensors: self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + # int4/fp4 post_process + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=False, + post_process_quantization=True, + rollout_engines=self.rollout_engines, + ) dist.barrier(group=get_gloo_group()) def _update_weight_from_distributed( @@ -297,3 +310,22 @@ def update_weights_from_distributed( handle.wait() return refs + + +def post_process_weights( + restore_weights_before_load: bool, + post_process_quantization: bool, + rollout_engines: Sequence[ActorHandle], +): + """ + Trigger post-process for int4/fp4 quantization on all rollout engines. 
+ """ + ray.get( + [ + engine.post_process_weights.remote( + restore_weights_before_load=restore_weights_before_load, + post_process_quantization=post_process_quantization, + ) + for engine in rollout_engines + ] + ) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 527d3cfe9..1acfabba3 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -16,6 +16,7 @@ from .update_weight_from_distributed import ( connect_rollout_engines_from_distributed, disconnect_rollout_engines_from_distributed, + post_process_weights, update_weights_from_distributed, ) @@ -112,6 +113,12 @@ def update_weights(self) -> None: rank = dist.get_rank() if rank == 0: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=True, + post_process_quantization=False, + rollout_engines=self.rollout_engines, + ) dist.barrier(group=get_gloo_group()) megatron_local_weights = self.weights_getter() @@ -121,6 +128,14 @@ def update_weights(self) -> None: ray.get(refs) del long_lived_tensors + # int4/fp4 post_process + if rank == 0: + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=False, + post_process_quantization=True, + rollout_engines=self.rollout_engines, + ) dist.barrier(group=get_gloo_group()) def _send_hf_params(self, hf_named_tensors) -> tuple[list[ObjectRef], Any]: diff --git a/miles/backends/sglang_utils/arguments.py b/miles/backends/sglang_utils/arguments.py index 1311350cb..3b8697545 100644 --- a/miles/backends/sglang_utils/arguments.py +++ b/miles/backends/sglang_utils/arguments.py @@ -1,5 +1,3 @@ -import sglang -from packaging.version import parse from sglang.srt.server_args import ServerArgs from miles.utils.http_utils import _wrap_ipv6 @@ -41,7 +39,6 @@ def add_sglang_arguments(parser): skipped_args = [ "model_path", - "dtype", "trust_remote_code", "random_seed", # memory @@ -115,9 +112,6 @@ def new_add_argument_wrapper(*name_or_flags, **kwargs): def validate_args(args): - if parse(sglang.__version__) == parse("0.4.10") and getattr(args, "sglang_enable_ep_moe", False): - args.sglang_expert_parallel_size = args.rollout_num_gpus_per_engine - args.sglang_tp_size = args.rollout_num_gpus_per_engine args.sglang_dp_size = args.sglang_data_parallel_size args.sglang_pp_size = args.sglang_pipeline_parallel_size diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 2e1afe625..f736cf97a 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -1,7 +1,10 @@ import dataclasses +import ipaddress import logging import multiprocessing +import os import time +from urllib.parse import quote import requests import sglang_router @@ -29,6 +32,24 @@ def get_base_gpu_id(args, rank): return start_index +def _to_local_gpu_id(physical_gpu_id: int) -> int: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if not cvd: + return physical_gpu_id # no remapping + # CUDA_VISIBLE_DEVICES can be like "4,5,6,7" + visible = [int(x) for x in cvd.split(",") if x.strip() != ""] + # In a remapped process, valid torch device indices are 
0..len(visible)-1 + if physical_gpu_id in visible: + return visible.index(physical_gpu_id) + # If we're already getting local IDs, allow them + if 0 <= physical_gpu_id < len(visible): + return physical_gpu_id + raise RuntimeError( + f"GPU id {physical_gpu_id} is not valid under CUDA_VISIBLE_DEVICES={cvd}. " + f"Expected one of {visible} (physical) or 0..{len(visible)-1} (local)." + ) + + def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: from sglang.srt.entrypoints.http_server import launch_server @@ -86,33 +107,46 @@ def _wait_server_healthy(base_url, api_key, is_process_alive): class SGLangEngine(RayActor): - def __init__(self, args, rank: int, worker_type: str = "regular"): + def __init__(self, args, rank: int, worker_type: str = "regular", base_gpu_id: int | None = None): self.args = args self.rank = rank self.worker_type = worker_type + self.base_gpu_id = base_gpu_id - def init(self, dist_init_addr, port, nccl_port, host=None): + def init(self, dist_init_addr, port, nccl_port, host=None, disaggregation_bootstrap_port=None): self.router_ip = self.args.sglang_router_ip self.router_port = self.args.sglang_router_port host = host or get_host_info()[1] - # support ipv6 address - if ":" in host and not host.startswith("["): - host = f"[{host}]" + def _format_v6_uri(addr): + if not addr or addr.startswith("["): + return addr + try: + if ipaddress.ip_address(addr).version == 6: + return f"[{addr}]" + except ValueError: + pass + return addr - # dist_init_addr may be 2605:...:10163, should split port - *addr_parts, port_str = dist_init_addr.split(":") - ipv6_addr = ":".join(addr_parts) - if ":" in ipv6_addr and not ipv6_addr.startswith("["): - dist_init_addr = f"[{ipv6_addr}]:{port_str}" + host = _format_v6_uri(host) + ip_part, port_part = dist_init_addr.rsplit(":", 1) + dist_init_addr = f"{_format_v6_uri(ip_part)}:{port_part}" server_args_dict, external_engine_need_check_fields = _compute_server_args( - self.args, self.rank, dist_init_addr, nccl_port, host, port, self.worker_type + self.args, + self.rank, + dist_init_addr, + nccl_port, + host, + port, + self.worker_type, + disaggregation_bootstrap_port, + base_gpu_id=self.base_gpu_id, ) self.node_rank = server_args_dict["node_rank"] - self.server_host = server_args_dict["host"] + self.server_host = server_args_dict["host"] # with [] if ipv6 self.server_port = server_args_dict["port"] if self.args.rollout_external: @@ -157,12 +191,15 @@ def _init_normal(self, server_args_dict): f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}" ) else: + payload = { + "url": f"http://{self.server_host}:{self.server_port}", + "worker_type": self.worker_type, + } + if self.worker_type == "prefill": + payload["bootstrap_port"] = server_args_dict["disaggregation_bootstrap_port"] response = requests.post( f"http://{self.router_ip}:{self.router_port}/workers", - json={ - "url": f"http://{self.server_host}:{self.server_port}", - "worker_type": self.worker_type, - }, + json=payload, ) response.raise_for_status() @@ -261,13 +298,31 @@ def shutdown(self): logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...") if self.node_rank == 0: worker_url = f"http://{self.server_host}:{self.server_port}" + response = None if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: response = requests.post( f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}" ) - else: + elif 
parse(sglang_router.__version__) < parse("0.3.0"): + worker_url = quote(worker_url, safe="") response = requests.delete(f"http://{self.router_ip}:{self.router_port}/workers/{worker_url}") - response.raise_for_status() + else: + try: + all_workers = requests.get(f"http://{self.router_ip}:{self.router_port}/workers").json()["workers"] + for worker in all_workers: + if worker["url"] == worker_url: + worker_id = worker["id"] + response = requests.delete( + f"http://{self.router_ip}:{self.router_port}/workers/{worker_id}" + ) + break + else: + logger.warning(f"Worker {worker_url} not found in router during shutdown.") + except Exception as e: + logger.warning(f"Failed to fetch workers list or remove worker: {e}") + + if response is not None: + response.raise_for_status() kill_process_tree(self.process.pid) def get_weight_version(self): @@ -292,7 +347,7 @@ def resume_memory_occupation(self, tags: list[str] = None): ) def check_weights(self, action: str): - return self._make_request("check_weights", {"action": action}) + return self._make_request("weights_checker", {"action": action}) def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): return self._make_request( @@ -346,6 +401,25 @@ def continue_generation(self): response.raise_for_status() return response + def post_process_weights( + self, + restore_weights_before_load: bool = False, + post_process_quantization: bool = False, + ): + """ + Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. + Note: The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. 
+ """ + + return self._make_request( + "post_process_weights", + { + "restore_weights_before_load": restore_weights_before_load, + "post_process_quantization": post_process_quantization, + }, + ) + def start_profile( self, # The output directory @@ -380,10 +454,33 @@ def stop_profile(self): response.raise_for_status() return response + def simulate_crash(self): + if self.args.rollout_external or not getattr(self, "process", None): + logger.info( + "simulate_crash called but no local engine process exists (rollout_external=%s); skip kill", + self.args.rollout_external, + ) + return -def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, worker_type: str = "regular"): + logger.info(f"Simulating crash on engine {self.server_host}:{self.server_port}...") + self.shutdown() + + +def _compute_server_args( + args, + rank, + dist_init_addr, + nccl_port, + host, + port, + worker_type: str = "regular", + disaggregation_bootstrap_port: int | None = None, + base_gpu_id: int | None = None, +): nnodes = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node) node_rank = rank % nnodes + base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(args, rank) + base = _to_local_gpu_id(base) kwargs = { "model_path": args.hf_checkpoint, "trust_remote_code": True, @@ -398,7 +495,7 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work "node_rank": node_rank, "dist_init_addr": dist_init_addr, "gpu_id_step": 1, - "base_gpu_id": get_base_gpu_id(args, rank), + "base_gpu_id": base, # parallel "tp_size": args.rollout_num_gpus_per_engine, "dp_size": args.sglang_dp_size, @@ -406,11 +503,17 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work "ep_size": args.sglang_ep_size, # always skip warmup to prevent warmup timeout. "skip_server_warmup": True, + # always enable draft weights cpu backup so that we run training without mtp weights. 
+ "enable_draft_weights_cpu_backup": True, } if worker_type == "prefill": kwargs["disaggregation_mode"] = "prefill" kwargs["load_balance_method"] = "round_robin" + assert ( + disaggregation_bootstrap_port is not None + ), "disaggregation_bootstrap_port must be set for prefill worker" + kwargs["disaggregation_bootstrap_port"] = disaggregation_bootstrap_port elif worker_type == "decode": kwargs["disaggregation_mode"] = "decode" kwargs["prefill_round_robin_balance"] = True @@ -419,11 +522,12 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" - external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] unused_keys = set(kwargs.keys()) for attr in dataclasses.fields(ServerArgs): + if worker_type == "decode" and attr.name == "enable_hierarchical_cache": + continue if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs: kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") unused_keys.discard(attr.name) @@ -444,4 +548,6 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work "nccl_port", "dist_init_addr", "skip_server_warmup", + "enable_draft_weights_cpu_backup", + "mem_fraction_static", ] diff --git a/miles/backends/training_utils/__init__.py b/miles/backends/training_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/miles/backends/training_utils/ci_utils.py b/miles/backends/training_utils/ci_utils.py new file mode 100644 index 000000000..0080afba4 --- /dev/null +++ b/miles/backends/training_utils/ci_utils.py @@ -0,0 +1,55 @@ +"""CI utilities for training backend testing.""" + +import logging +import math +from argparse import Namespace + +import torch + +logger = logging.getLogger(__name__) + + +def check_kl(args: Namespace, log_dict: dict[str, float], step_id: int, accumulated_step_id: int) -> None: + if step_id == 0 and "train/ppo_kl" in log_dict and "train/pg_clipfrac" in log_dict: + if args.multi_latent_attention: + # TODO: mla currently have non-zero kl, need further investigation + assert log_dict["train/ppo_kl"] < 1e-8, f"{log_dict=}" + else: + assert log_dict["train/ppo_kl"] == 0.0 and log_dict["train/pg_clipfrac"] == 0.0, f"{log_dict=}" + if accumulated_step_id == 0 and "train/kl_loss" in log_dict and not args.use_rollout_routing_replay: + assert log_dict["train/kl_loss"] == 0.0, f"{log_dict=}" + + +def check_grad_norm( + args: Namespace, + grad_norm: float, + rollout_id: int, + step_id: int, + role: str = "actor", + rank: int = 0, +) -> None: + + if rank != 0: + return + + if args.ci_save_grad_norm is not None: + ci_save_grad_norm_path = args.ci_save_grad_norm.format( + role=role, + rollout_id=rollout_id, + step_id=step_id, + ) + torch.save(grad_norm, ci_save_grad_norm_path) + + elif args.ci_load_grad_norm is not None: + ci_load_grad_norm_path = args.ci_load_grad_norm.format( + role=role, + rollout_id=rollout_id, + step_id=step_id, + ) + expected_grad_norm = torch.load(ci_load_grad_norm_path, weights_only=False) + assert math.isclose( + grad_norm, + expected_grad_norm, + rel_tol=0.03, + abs_tol=0.03, + ), f"grad norm mismatch: {grad_norm} != {expected_grad_norm}" diff --git a/miles/backends/megatron_utils/cp_utils.py b/miles/backends/training_utils/cp_utils.py similarity index 90% rename from miles/backends/megatron_utils/cp_utils.py rename to miles/backends/training_utils/cp_utils.py index 2e795d3d3..7d3f4b3e1 100644 --- 
a/miles/backends/megatron_utils/cp_utils.py +++ b/miles/backends/training_utils/cp_utils.py @@ -3,20 +3,22 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from megatron.core import mpu + +from .parallel import ParallelState def get_logits_and_tokens_offset_with_cp( total_length: int, response_length: int, + parallel_state: ParallelState, qkv_format: str = "thd", max_seq_len: int | None = None, ): """ All offsets start from the begining of the prompt. """ - cp_rank = mpu.get_context_parallel_rank() - cp_size = mpu.get_context_parallel_world_size() + cp_rank = parallel_state.cp_rank + cp_size = parallel_state.cp_size assert cp_size > 1 prompt_length = total_length - response_length @@ -54,6 +56,7 @@ def get_sum_of_sample_mean( total_lengths: list[int], response_lengths: list[int], loss_masks: list[torch.Tensor], + parallel_state: ParallelState, calculate_per_token_loss: bool = False, qkv_format: str = "thd", max_seq_lens: list[int] | None = None, @@ -61,7 +64,7 @@ def get_sum_of_sample_mean( """ Calculate correct sample mean for CP """ - cp_size = mpu.get_context_parallel_world_size() + cp_size = parallel_state.cp_size if cp_size == 1: def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: @@ -89,7 +92,7 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None prompt_length = total_length - response_length _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length, parallel_state, qkv_format, max_seq_len ) loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] @@ -119,18 +122,20 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token -def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor: +def all_gather_with_cp( + tensor: torch.Tensor, total_length: int, response_length: int, parallel_state: ParallelState +) -> torch.Tensor: """ Gather tensors across all ranks in the context parallel group. The first dimension of the output tensor will be the `response_length`. 
""" - cp_group = mpu.get_context_parallel_group() - cp_size = mpu.get_context_parallel_world_size() + cp_group = parallel_state.cp_group + cp_size = parallel_state.cp_size if cp_size == 1: return tensor - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length, parallel_state) prompt_length = total_length - response_length @@ -174,11 +179,12 @@ def zero(len: int) -> torch.Tensor: def slice_with_cp( tokens: torch.Tensor, pad_value: tuple[int, float, Callable], + parallel_state: ParallelState, qkv_format: str = "thd", max_seq_len: int | None = None, ) -> torch.Tensor: - cp_rank = mpu.get_context_parallel_rank() - cp_size = mpu.get_context_parallel_world_size() + cp_rank = parallel_state.cp_rank + cp_size = parallel_state.cp_size if qkv_format == "bshd": assert max_seq_len is not None @@ -219,19 +225,20 @@ def slice_log_prob_with_cp( log_prob: list[float] | torch.Tensor, total_length: int, response_length: int, + parallel_state: ParallelState, qkv_format: str = "thd", max_token_len: int | None = None, ) -> list[float] | torch.Tensor: assert len(log_prob) == response_length - cp_size = mpu.get_context_parallel_world_size() + cp_size = parallel_state.cp_size if cp_size == 1: return log_prob prompt_length = total_length - response_length _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_token_len + total_length, response_length, parallel_state, qkv_format, max_token_len ) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/training_utils/data.py similarity index 50% rename from miles/backends/megatron_utils/data.py rename to miles/backends/training_utils/data.py index f94d1b7e0..67bb30108 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -2,29 +2,95 @@ from argparse import Namespace from collections.abc import Sequence -import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F -from megatron.core import mpu -from megatron.core.packed_seq_params import PackedSeqParams -from miles.utils import train_metric_utils from miles.utils.data import get_minimum_num_micro_batch_size -from miles.utils.flops_utils import calculate_fwd_flops -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions from miles.utils.types import RolloutBatch -from ...utils import tracking_utils -from .cp_utils import get_sum_of_sample_mean, slice_with_cp +from ...utils.data import process_rollout_data +from ...utils.ray_utils import Box +from .cp_utils import slice_log_prob_with_cp, slice_with_cp +from .parallel import ParallelState logger = logging.getLogger(__name__) +def get_rollout_data(args: Namespace, rollout_data_ref: Box, parallel_state: ParallelState) -> RolloutBatch: + # Fetch data through ray on CPU, not sure if this will be performance bottleneck. + # Both first pp stage and the last pp stage will receive the data. 
+ rollout_data = process_rollout_data( + args, + rollout_data_ref, + parallel_state.dp_rank, + parallel_state.dp_size, + ) + # move tokens to GPU in advance + rollout_data["tokens"] = [ + torch.tensor(t, dtype=torch.long, device=torch.cuda.current_device()) for t in rollout_data["tokens"] + ] + rollout_data["loss_masks"] = [ + torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] + ] + if "multimodal_train_inputs" in rollout_data: + # Move multimodal training tensors to GPU in advance + rollout_data["multimodal_train_inputs"] = [ + ( + {key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()} + if mm_dict is not None + else None + ) + for mm_dict in rollout_data["multimodal_train_inputs"] + ] + + if args.qkv_format == "bshd": + # TODO: micro-batch wise dynamic, possibly move to @data.py:get_data_iterator + max_seq_len = max(rollout_data["total_lengths"]) + + # pad to reduce memory fragmentation and maybe make the computation faster + pad_size = parallel_state.tp_size * args.data_pad_size_multiplier + max_seq_len = (max_seq_len + pad_size - 1) // pad_size * pad_size + + rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) + + if "rollout_log_probs" in rollout_data: + rollout_data["rollout_log_probs"] = [ + torch.tensor( + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + parallel_state, + args.qkv_format, + rollout_data["max_seq_lens"][i] if args.qkv_format == "bshd" else None, + ), + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + for i, (log_prob, total_length, response_length) in enumerate( + zip( + rollout_data["rollout_log_probs"], + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, + ) + ) + ] + if "rollout_routed_experts" in rollout_data: + rollout_data["rollout_routed_experts"] = [torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"]] + return rollout_data + + def get_batch( - data_iterator: "DataIterator", keys: Sequence[str], pad_multiplier: int = 128, qkv_format: str = "thd" -) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]: + data_iterator: "DataIterator", + keys: Sequence[str], + parallel_state: ParallelState, + pad_multiplier: int = 128, + qkv_format: str = "thd", + get_position_ids: bool = False, +) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: """ Generate a CP-ready micro-batch with packed sequence parameters. @@ -49,25 +115,27 @@ def get_batch( assert "tokens" in keys batch = data_iterator.get_next(keys) - packed_seq_params = None + if "dynamic_global_batch_size" in data_iterator.rollout_data: + batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"] + tokens = batch["tokens"] # use 0 as the pad token id should be fine? 
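A quick arithmetic check of the bshd padding rule used in get_rollout_data above, where the maximum sequence length is rounded up to a multiple of tp_size * data_pad_size_multiplier; padded_max_seq_len below is an illustrative helper, not part of the patch.

def padded_max_seq_len(max_seq_len: int, tp_size: int, data_pad_size_multiplier: int) -> int:
    pad_size = tp_size * data_pad_size_multiplier
    return (max_seq_len + pad_size - 1) // pad_size * pad_size


assert padded_max_seq_len(1000, tp_size=2, data_pad_size_multiplier=128) == 1024
assert padded_max_seq_len(1024, tp_size=2, data_pad_size_multiplier=128) == 1024  # already aligned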
pad_token_id = 0 - pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier + pad_size = parallel_state.dp_size * pad_multiplier # for cp, we need all tokens to calculate logprob batch["unconcat_tokens"] = tokens - cp_size = mpu.get_context_parallel_world_size() + cp_size = parallel_state.cp_size if qkv_format == "bshd": max_seqlen = batch["max_seq_lens"][0] assert max([t.size(0) for t in tokens]) <= max_seqlen - tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens] + tokens = [slice_with_cp(t, pad_token_id, parallel_state, qkv_format, max_seqlen) for t in tokens] tokens = torch.stack(tokens) elif qkv_format == "thd": - tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] + tokens = [slice_with_cp(t, pad_token_id, parallel_state, qkv_format) for t in tokens] cu_seqlens = [0] for t in tokens: @@ -85,68 +153,74 @@ def get_batch( cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format="thd", - ) - tokens = tokens.unsqueeze(0) + + batch["cu_seqlens"] = cu_seqlens + batch["max_seqlen"] = max_seqlen else: raise ValueError(f"Unsupported qkv_format: {qkv_format}") batch["tokens"] = tokens - batch["packed_seq_params"] = packed_seq_params - return batch - -def gather_log_data( - metric_name: str, - args: Namespace, - rollout_id: int, - log_dict: dict[str, float], -) -> dict[str, float] | None: - """ - Gather per-rank metrics, reduce by mean on the DP source rank, and log. - - Expects `log_dict` to contain plain scalars. The DP source rank prints and - optionally logs to WandB/TensorBoard with a step derived from `rollout_id` and - batch sizes. Returns the reduced dict on the DP source rank; returns None on others. - """ - - if mpu.get_data_parallel_rank(with_context_parallel=True) == 0: - dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True) - - gathered_log_dict = [None] * dp_size - # Not sure if this will be a performance bottleneck. 
- dist.gather_object( - log_dict, - gathered_log_dict, - dst=mpu.get_data_parallel_src_rank(with_context_parallel=True), - group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), - ) - - reduced_log_dict = { - f"{metric_name}/{key}": sum([d[key] for d in gathered_log_dict]) / dp_size for key in log_dict - } - logger.info(f"{metric_name} {rollout_id}: {reduced_log_dict}") + if get_position_ids: + position_ids_list = [] + for t in batch["unconcat_tokens"]: + seq_len = t.size(0) + pos_ids = torch.arange(seq_len, device=t.device, dtype=torch.long) + position_ids_list.append(pos_ids) + + if qkv_format == "bshd": + position_ids = [slice_with_cp(p, 0, parallel_state, qkv_format, max_seqlen) for p in position_ids_list] + position_ids = torch.stack(position_ids) + elif qkv_format == "thd": + position_ids = [slice_with_cp(p, 0, parallel_state, qkv_format) for p in position_ids_list] + position_ids = torch.cat(position_ids) + if pad != 0: + position_ids = F.pad(position_ids, (0, pad), value=0) + position_ids = position_ids.unsqueeze(0) + + batch["position_ids"] = position_ids + + # loss masks + loss_masks = [] + for loss_mask, total_length, response_length in zip( + batch["loss_masks"], + batch["total_lengths"], + batch["response_lengths"], + strict=True, + ): + prompt_length = total_length - response_length + loss_mask = F.pad(loss_mask, (prompt_length - 1, 1), value=0) + loss_mask = slice_with_cp(loss_mask, 0, parallel_state, qkv_format, max_seqlen) + loss_masks.append(loss_mask) - # Calculate step once to avoid duplication - step = compute_rollout_step(args, rollout_id) - reduced_log_dict["rollout/step"] = step - tracking_utils.log(args, reduced_log_dict, step_key="rollout/step") + if qkv_format == "bshd": + loss_masks = torch.stack(loss_masks) + elif qkv_format == "thd": + loss_masks = torch.cat(loss_masks) + loss_masks = F.pad(loss_masks, (0, pad), value=0).unsqueeze(0) + + assert loss_masks.shape == tokens.shape, f"loss_masks.shape: {loss_masks.shape}, tokens.shape: {tokens.shape}" + batch["full_loss_masks"] = loss_masks + + # Process multimodal training tensors if present + multimodal_train_inputs = batch.get("multimodal_train_inputs", None) + if multimodal_train_inputs is not None: + multimodal_data = {} # key -> concatenated tensor + multimodal_num_items = {} # key -> list of item counts per sequence + for mm_input_dict in multimodal_train_inputs: + if mm_input_dict is not None: + for key, mm_tensor in mm_input_dict.items(): + if key not in multimodal_data: + multimodal_data[key] = mm_tensor + multimodal_num_items[key] = [mm_tensor.size(0)] + else: + multimodal_data[key] = torch.cat([multimodal_data[key], mm_tensor], dim=0) + multimodal_num_items[key].append(mm_tensor.size(0)) + batch["multimodal_train_inputs"] = multimodal_data + batch["multimodal_num_items"] = multimodal_num_items - return reduced_log_dict - else: - dist.gather_object( - log_dict, - None, - dst=mpu.get_data_parallel_src_rank(with_context_parallel=True), - group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), - ) - return None + return batch class DataIterator: @@ -216,6 +290,7 @@ def reset(self) -> "DataIterator": def get_data_iterator( args: Namespace, model: torch.nn.Module | Sequence[torch.nn.Module], + parallel_state: ParallelState, rollout_data: RolloutBatch, ) -> tuple[list[DataIterator], list[int]]: """ @@ -232,22 +307,24 @@ def get_data_iterator( - `data_iterators`: list of `DataIterator`, one per VPP stage (size 1 if VPP disabled) - `num_microbatches`: list[int], one per local step in the 
rollout (length = steps) """ - dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False) - dp_group = mpu.get_data_parallel_group() - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - if vpp_size is None: - vpp_size = 1 - if vpp_size > 1: - from megatron.core.utils import get_model_config - - config = get_model_config(model[0]) - microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage - cp_size = mpu.get_context_parallel_world_size() + dp_size = parallel_state.dp_size + dp_group = parallel_state.dp_group + vpp_size = parallel_state.vpp_size + microbatch_group_size_per_vp_stage = parallel_state.microbatch_group_size_per_vp_stage + + cp_size = parallel_state.cp_size num_local_samples = len(rollout_data["total_lengths"]) - num_local_gbs = args.global_batch_size // dp_size + global_batch_size = rollout_data.get("dynamic_global_batch_size", args.global_batch_size) + num_local_gbs = global_batch_size // dp_size num_steps_per_rollout = num_local_samples // num_local_gbs + if global_batch_size != args.global_batch_size: + logger.info( + f"Using dynamic global_batch_size={global_batch_size} (original={args.global_batch_size}), " + f"num_local_samples={num_local_samples}, num_steps_per_rollout={num_steps_per_rollout}" + ) + def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=None): data_iterator = [] for _ in range(vpp_size): @@ -304,155 +381,6 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices= ) -def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -> None: - """ - Summarize rollout fields and log reduced metrics on PP last stage, TP rank 0. - - - Tensor-valued lists are concatenated and averaged. For token-level metrics - like log-probs/returns/advantages/values, computes a CP-correct sample mean - using `loss_masks` and total/response lengths. - - Non-tensor lists are averaged elementwise. - - Scalars are converted to Python numbers. - """ - if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): - cp_size = mpu.get_context_parallel_world_size() - log_dict = {} - response_lengths = rollout_data["response_lengths"] - loss_masks = rollout_data["loss_masks"] - total_lengths = rollout_data["total_lengths"] - max_seq_lens = rollout_data.get("max_seq_lens", None) - - for key, val in rollout_data.items(): - if key in [ - "tokens", - "loss_masks", - "sample_indices", - "rollout_routed_experts", - "max_seq_lens", - ]: - continue - # Upload per sample mean for each rollout value - # There are the following assumptions: - # - Each dp rank has the same number of samples - if isinstance(val, (list, tuple)): - if isinstance(val[0], torch.Tensor): - # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be - # modified in place and will cause problem for the next rollout. 
- val = torch.cat(val).clone().detach() - if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: - sum_of_sample_mean = get_sum_of_sample_mean( - total_lengths, - response_lengths, - loss_masks, - qkv_format=args.qkv_format, - max_seq_lens=max_seq_lens, - ) - val = cp_size * sum_of_sample_mean(val) / len(loss_masks) - else: - val = val.mean() * cp_size - else: - val = sum(val) / len(val) - elif isinstance(val, torch.Tensor): - val = val.float().mean() - else: - raise ValueError(f"Unsupported type: {type(val)} for key: {key}") - log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val - - reduced_log_dict = gather_log_data("rollout", args, rollout_id, log_dict) - if args.ci_test and reduced_log_dict is not None: - if ( - rollout_id == 0 - and "rollout/log_probs" in reduced_log_dict - and "rollout/ref_log_probs" in reduced_log_dict - ): - assert reduced_log_dict["rollout/log_probs"] == reduced_log_dict["rollout/ref_log_probs"] - if "rollout/log_probs" in reduced_log_dict: - assert -0.5 < reduced_log_dict["rollout/log_probs"] < 0 - if "rollout/entropy" in reduced_log_dict: - assert 0 < reduced_log_dict["rollout/entropy"] < 0.5 - - if args.log_multi_turn: - log_multi_turn_data(rollout_id, args, rollout_data) - if args.log_passrate: - log_passrate(rollout_id, args, rollout_data) - - -def log_multi_turn_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -> None: - """ - Log multi-turn auxiliary metrics such as raw/observed response lengths and rounds. - - Operates only on PP last stage and TP rank 0. Uses GPU tensors when available - to compute statistics without host transfers. - """ - if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): - log_dict = {} - for key, val in rollout_data.items(): - if key == "loss_masks": - if val: # Check if val is not empty - device = val[0].device # Get device from first tensor - - # Vectorized length calculation using torch - raw_response_lengths = torch.tensor([v.shape[0] for v in val], dtype=torch.float32, device=device) - log_dict["raw_response_length/response_length_mean"] = raw_response_lengths.mean().item() - log_dict["raw_response_length/response_length_max"] = raw_response_lengths.max().item() - log_dict["raw_response_length/response_length_min"] = raw_response_lengths.min().item() - log_dict["raw_response_length/response_length_clip_ratio"] = ( - (raw_response_lengths >= args.rollout_max_response_len).float().mean().item() - ) - - # Vectorized sum calculation using torch - stay on GPU - wo_obs_response_lengths = torch.tensor( - [v.sum().item() for v in val], dtype=torch.float32, device=device - ) - log_dict["wo_obs_response_length/response_length_mean"] = wo_obs_response_lengths.mean().item() - log_dict["wo_obs_response_length/response_length_max"] = wo_obs_response_lengths.max().item() - log_dict["wo_obs_response_length/response_length_min"] = wo_obs_response_lengths.min().item() - if key == "round_number": - # Use numpy for vectorized round number statistics - round_number_array = np.array(val) - log_dict["multi_turn_metric/round_number_mean"] = np.mean(round_number_array) - log_dict["multi_turn_metric/round_number_max"] = np.max(round_number_array) - log_dict["multi_turn_metric/round_number_min"] = np.min(round_number_array) - gather_log_data("multi_turn", args, rollout_id, log_dict) - - -def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -> None: - """ - Compute pass@k metrics from `raw_reward` groups and log the results. 
- - `raw_reward` is reshaped to `[group_number, group_size]`, then pass@k is - estimated per problem and averaged. - """ - if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): - log_dict = {} - for key, val in rollout_data.items(): - if key != "raw_reward": - continue - - log_dict |= compute_pass_rate( - flat_rewards=val, - group_size=args.n_samples_per_prompt, - num_groups=args.rollout_batch_size, - ) - - gather_log_data("passrate", args, rollout_id, log_dict) - - -def log_perf_data(rollout_id: int, args: Namespace) -> None: - train_metric_utils.log_perf_data_raw( - rollout_id=rollout_id, - args=args, - is_primary_rank=( - mpu.get_tensor_model_parallel_rank() == 0 - and mpu.is_pipeline_last_stage() - and mpu.get_data_parallel_rank(with_context_parallel=True) == 0 - ), - compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args) - / dist.get_world_size() - / 1e12, - ) - - def sync_actor_critic_data( args: Namespace, rollout_data: RolloutBatch | None = None, diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py new file mode 100644 index 000000000..1a2f17602 --- /dev/null +++ b/miles/backends/training_utils/log_utils.py @@ -0,0 +1,408 @@ +import logging +from argparse import Namespace +from math import isclose + +import numpy as np +import torch +import torch.distributed as dist + +from miles.utils import train_metric_utils +from miles.utils.flops_utils import calculate_fwd_flops +from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step +from miles.utils.types import RolloutBatch + +from ...utils import tracking_utils +from .cp_utils import get_sum_of_sample_mean +from .data import DataIterator +from .parallel import ParallelState + +logger = logging.getLogger(__name__) + + +def gather_log_data( + metric_name: str, + args: Namespace, + rollout_id: int, + log_dict: dict[str, float], + parallel_state: ParallelState, +) -> dict[str, float] | None: + """ + Gather per-rank metrics, reduce by mean on the DP source rank, and log. + + Expects `log_dict` to contain plain scalars. The DP source rank prints and + optionally logs to WandB/TensorBoard with a step derived from `rollout_id` and + batch sizes. Returns the reduced dict on the DP source rank; returns None on others. + """ + + if parallel_state.dp_cp_rank == 0: + dp_size = parallel_state.dp_cp_size + + gathered_log_dict = [None] * dp_size + # Not sure if this will be a performance bottleneck. 
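+ # (gather_object pickles each rank's dict and moves it over the gloo/CPU group, so for a handful of scalar metrics per DP rank the cost should be negligible relative to a training step.)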
+ dist.gather_object( + log_dict, + gathered_log_dict, + dst=parallel_state.dp_src_rank, + group=parallel_state.dp_cp_group_gloo, + ) + + reduced_log_dict = { + f"{metric_name}/{key}": sum([d[key] for d in gathered_log_dict]) / dp_size for key in log_dict + } + logger.info(f"{metric_name} {rollout_id}: {reduced_log_dict}") + + # Calculate step once to avoid duplication + step = compute_rollout_step(args, rollout_id) + reduced_log_dict["rollout/step"] = step + tracking_utils.log(args, reduced_log_dict, step_key="rollout/step") + + return reduced_log_dict + else: + dist.gather_object( + log_dict, + None, + dst=parallel_state.dp_src_rank, + group=parallel_state.dp_cp_group_gloo, + ) + return None + + +def aggregate_forward_results( + forward_data_store: list[dict[str, list]], + data_iterator: DataIterator, + args: Namespace, + store_prefix: str = "", +) -> dict[str, list]: + rollout_data = {} + if not forward_data_store: + return rollout_data + + keys = forward_data_store[0].keys() + for key in keys: + values = [] + for batch_result in forward_data_store: + assert isinstance(batch_result[key], list), f"Expected list for key {key}, got {type(batch_result[key])}" + values += batch_result[key] + + # Handle dynamic batch size: restore original order + if args.use_dynamic_batch_size and hasattr(data_iterator, "micro_batch_indices"): + origin_values = [None] * len(values) + origin_indices = sum(data_iterator.micro_batch_indices, []) + for value, origin_index in zip(values, origin_indices, strict=False): + origin_values[origin_index] = value + values = origin_values + + rollout_data[key] = values + + return rollout_data + + +def log_rollout_data( + rollout_id: int, args: Namespace, rollout_data: RolloutBatch, parallel_state: ParallelState +) -> None: + """ + Summarize rollout fields and log reduced metrics on PP last stage, TP rank 0. + + - Tensor-valued lists are concatenated and averaged. For token-level metrics + like log-probs/returns/advantages/values, computes a CP-correct sample mean + using `loss_masks` and total/response lengths. + - Non-tensor lists are averaged elementwise. + - Scalars are converted to Python numbers. + """ + if parallel_state.tp_rank == 0 and parallel_state.is_pp_last_stage: + cp_size = parallel_state.cp_size + log_dict = {} + response_lengths = rollout_data["response_lengths"] + loss_masks = rollout_data["loss_masks"] + total_lengths = rollout_data["total_lengths"] + max_seq_lens = rollout_data.get("max_seq_lens", None) + + for key, val in rollout_data.items(): + if key in [ + "tokens", + "multimodal_train_inputs", + "loss_masks", + "sample_indices", + "rollout_routed_experts", + "max_seq_lens", + "dynamic_global_batch_size", + ]: + continue + # Upload per sample mean for each rollout value + # There are the following assumptions: + # - Each dp rank has the same number of samples + if isinstance(val, (list, tuple)): + if isinstance(val[0], torch.Tensor): + # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be + # modified in place and will cause problem for the next rollout. 
+ val = torch.cat(val).clone().detach() + if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + parallel_state, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) + val = cp_size * sum_of_sample_mean(val) / len(loss_masks) + else: + val = val.mean() * cp_size + else: + val = sum(val) / len(val) + elif isinstance(val, torch.Tensor): + val = val.float().mean() + else: + raise ValueError(f"Unsupported type: {type(val)} for key: {key}") + log_dict[key] = val.item() if isinstance(val, torch.Tensor) else val + + reduced_log_dict = gather_log_data("rollout", args, rollout_id, log_dict, parallel_state) + if args.ci_test and reduced_log_dict is not None: + if ( + rollout_id == 0 + and "rollout/log_probs" in reduced_log_dict + and "rollout/ref_log_probs" in reduced_log_dict + ): + assert reduced_log_dict["rollout/log_probs"] == reduced_log_dict["rollout/ref_log_probs"] + if "rollout/log_probs" in reduced_log_dict and "rollout/rollout_log_probs" in reduced_log_dict: + assert isclose( + reduced_log_dict["rollout/log_probs"], reduced_log_dict["rollout/rollout_log_probs"], abs_tol=0.03 + ) + if "rollout/entropy" in reduced_log_dict: + assert 0 < reduced_log_dict["rollout/entropy"] < 0.7 + + if args.log_multi_turn: + log_multi_turn_data(rollout_id, args, rollout_data, parallel_state) + if args.log_passrate: + log_passrate(rollout_id, args, rollout_data, parallel_state) + + if args.log_correct_samples: + if parallel_state.tp_rank == 0 and parallel_state.is_pp_last_stage: + cp_size = parallel_state.cp_size + log_dict = {} + response_lengths = rollout_data["response_lengths"] + loss_masks = rollout_data["loss_masks"] + total_lengths = rollout_data["total_lengths"] + + def quantile(total_value, n_quantiles, data) -> dict: + import math + + assert n_quantiles > 1, f"n_quantiles({n_quantiles}) must be greater than 1." + + quantiles = [((i + 1) / n_quantiles) for i in range(n_quantiles)] + cut_points = [total_value * q for q in quantiles] + cut_points[-1] = total_value + + count = [0] * n_quantiles + for d in data: + for i, point in enumerate(cut_points): + if d <= point: + count[i] += 1 + break + + total = sum(count) + 1e-9 + percentile = [c / total for c in count] + + percentile = {f"p{min(math.ceil(q*100),100)}": p for q, p in zip(quantiles, percentile, strict=True)} + return percentile + + raw_rewards = rollout_data["raw_reward"] + # Additional metrics for correct cases are calculated separately below.
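+ # For reference, quantile() above buckets values into quartiles of `total_value`, e.g. quantile(1000, 4, [100, 600, 900]) gives roughly {"p25": 0.33, "p50": 0.0, "p75": 0.33, "p100": 0.33}.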
+ correct_response_lengths = [] + correct_total_lengths = [] + correct_loss_masks = [] + correct_entropy = [] + for i, raw_reward in enumerate(raw_rewards): + if raw_reward == 1: + correct_response_lengths.append(response_lengths[i]) + correct_total_lengths.append(total_lengths[i]) + correct_loss_masks.append(loss_masks[i]) + correct_entropy.append(-rollout_data["log_probs"][i]) + num_correct_responses = len(correct_total_lengths) + rollout_data["correct_response_lengths"] = correct_response_lengths + correct_response_length_percentile = quantile( + args.rollout_max_response_len, 4, rollout_data["correct_response_lengths"] + ) + for p, val in correct_response_length_percentile.items(): + rollout_data[f"correct_length/{p}"] = [val] * num_correct_responses + if len(correct_entropy) > 0: + sum_of_sample_mean = get_sum_of_sample_mean( + correct_total_lengths, correct_response_lengths, correct_loss_masks, parallel_state + ) + correct_entropy = sum_of_sample_mean(torch.cat(correct_entropy, dim=0)) + rollout_data["correct_entropy"] = [correct_entropy.item()] * num_correct_responses + else: + rollout_data["correct_entropy"] = [0] * num_correct_responses + + +def log_multi_turn_data( + rollout_id: int, args: Namespace, rollout_data: RolloutBatch, parallel_state: ParallelState +) -> None: + """ + Log multi-turn auxiliary metrics such as raw/observed response lengths and rounds. + + Operates only on PP last stage and TP rank 0. Uses GPU tensors when available + to compute statistics without host transfers. + """ + if parallel_state.tp_rank == 0 and parallel_state.is_pp_last_stage: + log_dict = {} + for key, val in rollout_data.items(): + if key == "loss_masks": + if val: # Check if val is not empty + device = val[0].device # Get device from first tensor + + # Vectorized length calculation using torch + raw_response_lengths = torch.tensor([v.shape[0] for v in val], dtype=torch.float32, device=device) + log_dict["raw_response_length/response_length_mean"] = raw_response_lengths.mean().item() + log_dict["raw_response_length/response_length_max"] = raw_response_lengths.max().item() + log_dict["raw_response_length/response_length_min"] = raw_response_lengths.min().item() + log_dict["raw_response_length/response_length_clip_ratio"] = ( + (raw_response_lengths >= args.rollout_max_response_len).float().mean().item() + ) + + # Vectorized sum calculation using torch - stay on GPU + wo_obs_response_lengths = torch.tensor( + [v.sum().item() for v in val], dtype=torch.float32, device=device + ) + log_dict["wo_obs_response_length/response_length_mean"] = wo_obs_response_lengths.mean().item() + log_dict["wo_obs_response_length/response_length_max"] = wo_obs_response_lengths.max().item() + log_dict["wo_obs_response_length/response_length_min"] = wo_obs_response_lengths.min().item() + if key == "round_number": + # Use numpy for vectorized round number statistics + round_number_array = np.array(val) + log_dict["multi_turn_metric/round_number_mean"] = np.mean(round_number_array) + log_dict["multi_turn_metric/round_number_max"] = np.max(round_number_array) + log_dict["multi_turn_metric/round_number_min"] = np.min(round_number_array) + gather_log_data("multi_turn", args, rollout_id, log_dict, parallel_state) + + +def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch, parallel_state: ParallelState) -> None: + """ + Compute pass@k metrics from `raw_reward` groups and log the results. + + `raw_reward` is reshaped to `[group_number, group_size]`, then pass@k is + estimated per problem and averaged. 
+ """ + if parallel_state.tp_rank == 0 and parallel_state.is_pp_last_stage: + log_dict = {} + for key, val in rollout_data.items(): + if key != "raw_reward": + continue + + log_dict |= compute_pass_rate( + flat_rewards=val, + group_size=args.n_samples_per_prompt, + num_groups=args.rollout_batch_size, + ) + + gather_log_data("passrate", args, rollout_id, log_dict, parallel_state) + + +def log_perf_data(rollout_id: int, args: Namespace, parallel_state: ParallelState) -> None: + train_metric_utils.log_perf_data_raw( + rollout_id=rollout_id, + args=args, + is_primary_rank=( + parallel_state.tp_rank == 0 and parallel_state.is_pp_last_stage and parallel_state.dp_cp_rank == 0 + ), + compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args) + / dist.get_world_size() + / 1e12, + ) + + +def aggregate_train_losses( + losses_reduced: list[dict[str, list[str] | torch.Tensor]], + parallel_state: ParallelState, +) -> dict[str, float]: + """Aggregate loss metrics across micro-batches. + + Sums loss values across all micro-batches, performs all-reduce across + the data-parallel group, and computes per-sample/token averages. + + Args: + losses_reduced: List of log_dict from each micro-batch. + Each log_dict has format: {"keys": list[str], "values": torch.Tensor} + parallel_state: Parallel state containing dp_group and cp_size. + + Returns: + Dictionary mapping metric names to averaged values. + """ + if not losses_reduced: + return {} + + keys = losses_reduced[0]["keys"] + + values = None + for log_dict in losses_reduced: + if values is None: + values = log_dict["values"].clone() + else: + values += log_dict["values"] + + assert len(keys) + 1 == values.numel(), f"Expected {len(keys) + 1} values, got {values.numel()}" + + dist.all_reduce(values, op=dist.ReduceOp.SUM, group=parallel_state.dp_cp_group) + + loss_reduced = {} + values = values.tolist() + num_samples_or_tokens = values[0] + + for key, value in zip(keys, values[1:], strict=False): + loss_reduced[key] = value * parallel_state.cp_size / num_samples_or_tokens + + return loss_reduced + + +def log_train_step( + args: Namespace, + loss_dict: dict[str, float], + grad_norm: float, + rollout_id: int, + step_id: int, + num_steps_per_rollout: int, + role: str = "actor", + extra_metrics: dict[str, float] | None = None, + should_log: bool | None = None, +) -> dict[str, float]: + """Log training metrics for one step. + + Formats loss metrics, gradient norm, and extra metrics (e.g., learning rates, MTP loss) for tracking. + + Args: + args: Configuration. + loss_dict: Dictionary of loss metrics from aggregate_train_losses. + grad_norm: Gradient norm after clipping. + rollout_id: Rollout ID. + step_id: Step ID within the rollout. + num_steps_per_rollout: Total number of steps per rollout. + role: Role name (e.g., "actor", "critic"). + extra_metrics: Optional extra metrics to log (e.g., learning rates, MTP loss). + should_log: Optional override for logging condition. If None, uses rank == 0. + + Returns: + The formatted log_dict (for CI tests or other uses). 
+ """ + accumulated_step_id = rollout_id * num_steps_per_rollout + step_id + role_tag = "" if role == "actor" else f"{role}-" + + log_dict_out = { + f"train/{role_tag}{key}": val.mean().item() if isinstance(val, torch.Tensor) else val + for key, val in loss_dict.items() + } + log_dict_out[f"train/{role_tag}grad_norm"] = float(grad_norm) + + if extra_metrics: + for key, val in extra_metrics.items(): + log_dict_out[f"train/{role_tag}{key}"] = val + + log_dict_out["train/step"] = accumulated_step_id + + if should_log is None: + should_log = dist.get_rank() == 0 + + if should_log: + tracking_utils.log(args, log_dict_out, step_key="train/step") + logger.info(f"{role_tag}step {accumulated_step_id}: {log_dict_out}") + + return log_dict_out diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/training_utils/loss.py similarity index 88% rename from miles/backends/megatron_utils/loss.py rename to miles/backends/training_utils/loss.py index d7b72a512..abc790761 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -3,7 +3,6 @@ from typing import Any import torch -from megatron.core import mpu from torch.utils.checkpoint import checkpoint from miles.utils.distributed_utils import distributed_masked_whiten @@ -22,12 +21,14 @@ from miles.utils.types import RolloutBatch from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean +from .parallel import ParallelState def get_responses( logits: torch.Tensor, *, args: Namespace, + parallel_state: ParallelState, unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], @@ -68,7 +69,7 @@ def get_responses( logits = logits.div(args.rollout_temperature) - cp_size = mpu.get_context_parallel_world_size() + cp_size = parallel_state.cp_size end = 0 for i, (tokens, total_length, response_length) in enumerate( zip(unconcat_tokens, total_lengths, response_lengths, strict=False) @@ -87,7 +88,7 @@ def get_responses( else: # TODO: this is super ugly... do better abstraction. 
chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length, parallel_state, qkv_format, max_seq_len ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -112,6 +113,7 @@ def get_log_probs_and_entropy( logits: torch.Tensor, *, args: Namespace, + parallel_state: ParallelState, unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], @@ -147,13 +149,18 @@ def get_log_probs_and_entropy( for logits_chunk, tokens_chunk in get_responses( logits, args=args, + parallel_state=parallel_state, unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, max_seq_lens=max_seq_lens, ): log_prob, entropy = calculate_log_probs_and_entropy( - logits_chunk, tokens_chunk, mpu.get_tensor_model_parallel_group(), with_entropy=with_entropy + logits_chunk, + tokens_chunk, + parallel_state.tp_group, + with_entropy=with_entropy, + chunk_size=args.log_probs_chunk_size, ) log_probs_list.append(log_prob.squeeze(-1)) @@ -171,6 +178,7 @@ def get_values( logits: torch.Tensor, *, args: Namespace, + parallel_state: ParallelState, unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], @@ -201,6 +209,7 @@ def get_values( for logits_chunk, _ in get_responses( logits, args=args, + parallel_state=parallel_state, unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, @@ -214,7 +223,7 @@ def get_values( } -def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) -> None: +def compute_advantages_and_returns(args: Namespace, parallel_state: ParallelState, rollout_data: RolloutBatch) -> None: """Compute advantages and returns in-place based on `args.advantage_estimator`. This function extracts rewards, log-probs, values, and masks from @@ -269,17 +278,17 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) advantages = [r for r in returns] elif args.advantage_estimator == "ppo": - # TODO: optimize this old_rewards = rewards rewards = [] + kl_coef = -args.kl_coef + cp_rank = parallel_state.cp_rank for reward, k in zip(old_rewards, kl, strict=False): - k *= -args.kl_coef - cp_rank = mpu.get_context_parallel_rank() + k *= kl_coef if cp_rank == 0: k[-1] += reward rewards.append(k) advantages, returns = get_advantages_and_returns_batch( - total_lengths, response_lengths, values, rewards, args.gamma, args.lambd + total_lengths, response_lengths, values, rewards, args.gamma, args.lambd, parallel_state ) elif args.advantage_estimator == "reinforce_plus_plus": @@ -292,6 +301,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) total_lengths=total_lengths, kl_coef=args.kl_coef, gamma=args.gamma, + parallel_state=parallel_state, ) advantages = [r for r in returns] @@ -327,7 +337,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) # TODO: OpenRLHF always does advantages normalization but veRL doesn't seem to do it. 
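+ # distributed_masked_whiten below presumably computes adv = (adv - mean) / std with mean/std taken only over masked-in positions and all-reduced across the DP group, so every rank normalizes with the same statistics.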
if args.normalize_advantages: all_advs = torch.cat(advantages) - cp_size = mpu.get_context_parallel_world_size() + cp_size = parallel_state.cp_size if cp_size == 1: all_masks = torch.cat(loss_masks) else: @@ -339,7 +349,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( - total_len, response_len, args.qkv_format, max_seq_len + total_len, response_len, parallel_state, args.qkv_format, max_seq_len ) # Convert global offsets to response-space offsets @@ -369,7 +379,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) assert ( all_advs.size() == all_masks.size() ), f"Shape mismatch before whitening: advantages {all_advs.size()}, masks {all_masks.size()}" - dp_group = mpu.get_data_parallel_group() + dp_group = parallel_state.dp_group whitened_advs_flat = distributed_masked_whiten( all_advs, @@ -436,6 +446,7 @@ def icepop_function( def policy_loss_function( args: Namespace, + parallel_state: ParallelState, batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], @@ -445,7 +456,7 @@ def policy_loss_function( Computes current log-probabilities and entropy from model logits, then calculates PPO-style clipped policy gradient loss. For GSPO, gathers full sequences via context-parallel all-gather before computing per-sample - KL. Optionally applies TIS (Temporal Importance Sampling) correction and + KL. Optionally applies TIS (Truncated Importance Sampling) correction and adds KL loss term if configured. Args: @@ -474,6 +485,7 @@ def policy_loss_function( log_probs_and_entropy = get_log_probs_and_entropy( logits, args=args, + parallel_state=parallel_state, unconcat_tokens=batch["unconcat_tokens"], total_lengths=total_lengths, response_lengths=response_lengths, @@ -490,13 +502,13 @@ def policy_loss_function( full_old_log_probs = None if need_full_log_probs: full_log_probs = [ - all_gather_with_cp(log_prob, total_length, response_length) + all_gather_with_cp(log_prob, total_length, response_length, parallel_state) for log_prob, total_length, response_length in zip( log_probs, total_lengths, response_lengths, strict=False ) ] full_old_log_probs = [ - all_gather_with_cp(old_log_prob, total_length, response_length) + all_gather_with_cp(old_log_prob, total_length, response_length, parallel_state) for old_log_prob, total_length, response_length in zip( old_log_probs, total_lengths, response_lengths, strict=False ) @@ -534,6 +546,16 @@ def policy_loss_function( # Apply off-policy correction using importance sampling if enabled if args.get_mismatch_metrics or args.use_tis: + # NOTE: + # `tis_func` may apply rejection-sampling style masking (RS) and return `modified_response_masks`. + # We rebuild `sum_of_sample_mean` with those masks to correct denominators for loss/backprop. + # + # However, mismatch/TIS/RS metrics (e.g., "truncate_fraction") are often defined over the + # *pre-RS* valid tokens. If we aggregate metrics with `modified_response_masks`, the rejected + # tokens are excluded from the denominator and the metric can be artificially driven to 0. + # Keep a copy of the original reducer (based on `batch["loss_masks"]`) for metric aggregation. 
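+ # Example: if RS happens to reject exactly the mismatched tokens, a metric like truncate_fraction computed with the post-RS masks would be 0 by construction, while the pre-RS masks keep reporting the true fraction.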
+ sum_of_sample_mean_for_mismatch_metrics = sum_of_sample_mean + assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" ois = (-ppo_kl).exp() @@ -545,6 +567,7 @@ def policy_loss_function( "loss_masks": batch["loss_masks"], "total_lengths": total_lengths, "response_lengths": response_lengths, + "parallel_state": parallel_state, } if args.custom_tis_function_path is not None: @@ -559,12 +582,24 @@ def policy_loss_function( total_lengths, response_lengths, modified_response_masks, + parallel_state, args.calculate_per_token_loss, args.qkv_format, - batch.get("max_seq_lens", None), + max_seq_lens, + ) + + # Determine pg_loss reducer: use custom if specified, otherwise default + if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: + custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) + # Determine which loss_masks to use for pg_loss reducer + pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] + pg_loss_reducer = custom_pg_loss_reducer_func( + total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss ) + else: + pg_loss_reducer = sum_of_sample_mean - pg_loss = sum_of_sample_mean(pg_loss) + pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) @@ -615,11 +650,13 @@ def policy_loss_function( reported_loss["kl_loss"] = kl_loss.clone().detach() if args.get_mismatch_metrics or args.use_tis: - reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() + # Aggregate mismatch/TIS/RS related metrics with the *pre-RS* masks. + # See comment above where `sum_of_sample_mean_for_mismatch_metrics` is defined. + reported_loss["ois"] = sum_of_sample_mean_for_mismatch_metrics(ois).clone().detach() # Assume all metrics are already cloned and detached for metric_key, metric_value in tis_metrics.items(): key_name = f"{metric_key}" - reported_loss[key_name] = sum_of_sample_mean(metric_value) + reported_loss[key_name] = sum_of_sample_mean_for_mismatch_metrics(metric_value) if args.use_opsm: reported_loss["opsm_clipfrac"] = opsm_clipfrac @@ -629,6 +666,7 @@ def policy_loss_function( def value_loss_function( args: Namespace, + parallel_state: ParallelState, batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], @@ -655,6 +693,7 @@ def value_loss_function( values = get_values( logits, args=args, + parallel_state=parallel_state, unconcat_tokens=batch["unconcat_tokens"], total_lengths=batch["total_lengths"], response_lengths=batch["response_lengths"], @@ -687,6 +726,7 @@ def value_loss_function( def sft_loss_function( args: Namespace, + parallel_state: ParallelState, batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], @@ -713,6 +753,7 @@ def sft_loss_function( log_probs_and_entropy = get_log_probs_and_entropy( logits, args=args, + parallel_state=parallel_state, unconcat_tokens=batch["unconcat_tokens"], total_lengths=total_lengths, response_lengths=response_lengths, @@ -738,9 +779,11 @@ def sft_loss_function( def loss_function( args: Namespace, + parallel_state: ParallelState, batch: RolloutBatch, num_microbatches: int, logits: torch.Tensor, + apply_megatron_loss_scaling: bool = False, ) -> tuple[torch.Tensor, int | torch.Tensor, dict[str, list[str] | torch.Tensor]]: """Dispatch to the configured loss and rescale for Megatron integration. 
@@ -772,6 +815,7 @@ def loss_function( batch["total_lengths"], batch["response_lengths"], batch["loss_masks"], + parallel_state, args.calculate_per_token_loss, args.qkv_format, batch.get("max_seq_lens", None), @@ -790,24 +834,31 @@ def loss_function( raise ValueError(f"Unknown loss type: {args.loss_type}") if args.recompute_loss_function: - loss, log = checkpoint(func, args, batch, logits, sum_of_sample_mean) + loss, log = checkpoint( + func, + args, + parallel_state, + batch, + logits, + sum_of_sample_mean, + ) else: - loss, log = func(args, batch, logits, sum_of_sample_mean) + loss, log = func(args, parallel_state, batch, logits, sum_of_sample_mean) # Here we need to divide by cp_size because to cancel the multiply in Megatron. + global_batch_size = batch.get("dynamic_global_batch_size", args.global_batch_size) if not args.calculate_per_token_loss: - loss = ( - loss - * num_microbatches - / args.global_batch_size - * mpu.get_data_parallel_world_size(with_context_parallel=True) - ) + if apply_megatron_loss_scaling: + loss = loss * num_microbatches / global_batch_size * parallel_state.dp_cp_size + else: + loss = loss / global_batch_size * parallel_state.dp_size else: - loss = loss * mpu.get_context_parallel_world_size() + if apply_megatron_loss_scaling: + loss = loss * parallel_state.cp_size return ( loss, - num_tokens if args.calculate_per_token_loss else 1, + torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.int, device=logits.device), { "keys": list(log.keys()), "values": torch.tensor( diff --git a/miles/backends/training_utils/parallel.py b/miles/backends/training_utils/parallel.py new file mode 100644 index 000000000..4283e8731 --- /dev/null +++ b/miles/backends/training_utils/parallel.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +import torch.distributed as dist + + +@dataclass +class ParallelState: + """Core parallel state shared across all backends. + Required by the general training utils. + """ + + dp_rank: int + dp_src_rank: int + dp_size: int + cp_rank: int + cp_size: int + dp_cp_rank: int + dp_cp_size: int + dp_group: dist.ProcessGroup | None + dp_cp_group: dist.ProcessGroup | None + dp_cp_group_gloo: dist.ProcessGroup | None + cp_group: dist.ProcessGroup | None + tp_size: int + tp_rank: int + tp_group: dist.ProcessGroup | None + is_pp_last_stage: bool = True + vpp_size: int | None = 1 + microbatch_group_size_per_vp_stage: int | None = None diff --git a/miles/ray/actor_group.py b/miles/ray/actor_group.py index 21a26abbe..0b11df312 100644 --- a/miles/ray/actor_group.py +++ b/miles/ray/actor_group.py @@ -31,7 +31,7 @@ def __init__( args, num_nodes, num_gpus_per_node, - pg: tuple[PlacementGroup, list[int]], + pg: tuple[PlacementGroup, list[int], list[int]], num_gpus_per_actor: float = 1, role: str = "actor", ) -> None: @@ -48,7 +48,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): # Use placement group to lock resources for models of same type assert pg is not None - pg, reordered_bundle_indices = pg + pg, reordered_bundle_indices, _reordered_gpu_ids = pg env_vars = { # because sglang will always set NCCL_CUMEM_ENABLE to 0 @@ -72,12 +72,13 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): env_vars["TMS_INIT_ENABLE"] = "1" env_vars["TMS_INIT_ENABLE_CPU_BACKUP"] = "1" - if self.args.use_routing_replay: + # We cannot do routing replay for critic. 
+ if self.args.use_routing_replay and self.role == "actor": env_vars["ENABLE_ROUTING_REPLAY"] = "1" backend = self.args.train_backend if backend == "megatron": - from miles.backends.megatron_utils import MegatronTrainRayActor + from miles.backends.megatron_utils.actor import MegatronTrainRayActor actor_impl = MegatronTrainRayActor @@ -115,9 +116,9 @@ def async_train(self, rollout_id, rollout_data_ref): """Do one rollout training""" return [actor.train.remote(rollout_id, rollout_data_ref) for actor in self._actor_handlers] - def save_model(self, step_id): - """Save actor model on rank 0.""" - return ray.get([actor.save_model.remote(step_id) for actor in self._actor_handlers]) + def save_model(self, rollout_id, force_sync=False): + """Save actor model""" + return ray.get([actor.save_model.remote(rollout_id, force_sync=force_sync) for actor in self._actor_handlers]) def update_weights(self): """Broadcast weights from rank 0 to all other ranks.""" diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index b6fb7a20b..eb232b161 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -1,5 +1,6 @@ import logging import socket + import ray from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -60,7 +61,11 @@ def _create_placement_group(num_gpus): ray.kill(actor) bundle_infos = [(i, gpu_ids[i][0], gpu_ids[i][1]) for i in range(num_bundles)] - pg_reordered_bundle_indices = [bundle_info[0] for bundle_info in sorted(bundle_infos, key=sort_key)] + sorted_bundle_infos = sorted(bundle_infos, key=sort_key) + pg_reordered_bundle_indices = [info[0] for info in sorted_bundle_infos] + # Map from logical index -> physical GPU ID + pg_reordered_gpu_ids = [gpu_ids[info[0]][1] for info in sorted_bundle_infos] + for i in range(num_bundles): actual_bundle_index = pg_reordered_bundle_indices[i] logger.info( @@ -68,7 +73,7 @@ def _create_placement_group(num_gpus): f"node: {gpu_ids[actual_bundle_index][0]}, gpu: {gpu_ids[actual_bundle_index][1]}" ) - return pg, pg_reordered_bundle_indices + return pg, pg_reordered_bundle_indices, pg_reordered_gpu_ids def create_placement_groups(args): @@ -99,16 +104,18 @@ def create_placement_groups(args): rollout_offset += args.critic_num_nodes * args.critic_num_gpus_per_node logger.info(f"Creating placement group with {num_gpus} GPUs...") - pg, actor_pg_reordered_bundle_indices = _create_placement_group(num_gpus) + pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus) rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:] + rollout_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[rollout_offset:] if args.use_critic: critic_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[critic_offset:] + critic_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[critic_offset:] return { - "actor": (pg, actor_pg_reordered_bundle_indices), - "critic": (pg, critic_pg_reordered_bundle_indices) if args.use_critic else None, - "rollout": (pg, rollout_pg_reordered_bundle_indices), + "actor": (pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids), + "critic": (pg, critic_pg_reordered_bundle_indices, critic_pg_reordered_gpu_ids) if args.use_critic else None, + "rollout": (pg, rollout_pg_reordered_bundle_indices, rollout_pg_reordered_gpu_ids), } diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 9ee0fbb8a..27211845d 100644 --- a/miles/ray/rollout.py +++ 
b/miles/ray/rollout.py @@ -1,9 +1,8 @@ +import itertools import logging import multiprocessing -import os import random import time -from glob import glob from pathlib import Path from typing import Any @@ -11,10 +10,18 @@ import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import call_rollout_fn +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnTrainInput, + call_rollout_fn, +) +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -53,11 +60,22 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.use_experimental_refactor = enable_experimental_rollout_refactor() + if self.use_experimental_refactor: + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + else: + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) + self.custom_convert_samples_to_train_data_func = None + if self.args.custom_convert_samples_to_train_data_path is not None: + self.custom_convert_samples_to_train_data_func = load_function( + self.args.custom_convert_samples_to_train_data_path + ) logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.") logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") @@ -70,14 +88,41 @@ def __init__(self, args, pg): self.num_new_engines = init_rollout_engines(args, pg, self.all_rollout_engines) self.nodes_per_engine = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() + self.rollout_id = -1 self._metric_checker = MetricChecker.maybe_create(args) + self._health_monitor = None if self.args.use_fault_tolerance: self._health_monitor = RolloutHealthMonitor(self, args) + self._health_monitor.start() # Start the monitor thread (in paused state) + self._ci_fault_injection_pending = self.args.ci_test # Flag for CI fault injection + + def _try_ci_fault_injection(self): + """Try to inject fault during generate (when health monitor is running).""" + if not self._ci_fault_injection_pending: + return + + # Only inject fault once + self._ci_fault_injection_pending = False + + if self.all_rollout_engines and self.all_rollout_engines[0]: + logger.info("CI Fault Injection: 
Simulating crash on engine 0 during generate") + try: + # This will cause the ray actor to exit + self.all_rollout_engines[0].simulate_crash.remote() + # Wait for health monitor to detect the crash and mark engine as None + # health_check_interval + health_check_timeout + buffer + wait_time = self.args.rollout_health_check_interval + self.args.rollout_health_check_timeout + 5 + logger.info(f"CI Fault Injection: Waiting {wait_time}s for health monitor to detect crash") + time.sleep(wait_time) + except Exception as e: + logger.warning(f"CI Fault Injection failed: {e}") def dispose(self): if self._metric_checker is not None: self._metric_checker.dispose() + if self._health_monitor is not None: + self._health_monitor.stop() # TODO maybe rename "rollout_engines" and "all_rollout_engines" later @property @@ -93,28 +138,29 @@ def get_num_rollout_per_epoch(self): return len(self.data_source.dataset) // self.args.rollout_batch_size def generate(self, rollout_id): - monitor_started = self.args.use_fault_tolerance and self._health_monitor.start() start_time = time.time() - try: - data, metrics = self._get_rollout_data(rollout_id=rollout_id) - self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False) - _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time) - data = self._convert_samples_to_train_data(data) - return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"]) - finally: - if monitor_started: - self._health_monitor.stop() - self.num_new_engines = init_rollout_engines(self.args, self.pg, self.all_rollout_engines) - else: - self.num_new_engines = 0 + self.rollout_id = rollout_id + self.health_monitoring_resume() + if self.args.ci_test and self.args.use_fault_tolerance and rollout_id >= 2: + self._try_ci_fault_injection() + data, metrics = self._get_rollout_data(rollout_id=rollout_id) + self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False) + _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time) + data = self._convert_samples_to_train_data(data) + return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"]) def eval(self, rollout_id): if self.args.debug_train_only: # if debug train only, we don't generate evaluation data return + self.health_monitoring_resume() - # TODO: add fault tolerance to eval - result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) + if self.use_experimental_refactor: + result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) + else: + result = call_rollout_fn( + self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True + ) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -128,10 +174,54 @@ def load(self, rollout_id=None): self.data_source.load(rollout_id) def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + self.health_monitoring_pause() + return ray.get( + [engine.release_memory_occupation.remote() for engine in self.rollout_engines if engine is not None] + ) + + def onload(self, tags: list[str] | None = None): + return ray.get( + [ + engine.resume_memory_occupation.remote(tags=tags) + for engine in self.rollout_engines + if engine is not None + ] + ) + + def onload_weights(self): + 
self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + + def onload_kv(self): + self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH]) + + def recover_rollout_engines(self): + """Restart any dead rollout engines and update num_new_engines for update_weights detection.""" + self.health_monitoring_pause() + if self.rollout_id == -1: + return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines + + dead_indices = [i for i, engine in enumerate(self.all_rollout_engines) if engine is None] + self.num_new_engines = init_rollout_engines(self.args, self.pg, self.all_rollout_engines) + logger.info(f"Recovered {self.num_new_engines} dead rollout engines") + assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" + if self.args.offload_rollout and dead_indices: + new_engines = [self.all_rollout_engines[i] for i in dead_indices] + ray.get([engine.release_memory_occupation.remote() for engine in new_engines]) + ray.get([engine.resume_memory_occupation.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]) for engine in new_engines]) + + return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines - def onload(self, tags: list[str] = None): - return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) + def clear_num_new_engines(self): + # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights + self.num_new_engines = 0 + + def health_monitoring_pause(self) -> None: + if self._health_monitor is not None: + self._health_monitor.pause() + + def health_monitoring_resume(self) -> None: + if self._health_monitor is not None: + self._health_monitor.resume() def check_weights(self, action: str): return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) @@ -139,7 +229,7 @@ def check_weights(self, action: str): def _get_rollout_data(self, rollout_id): if self.args.load_debug_rollout_data: data = torch.load( - open(self.args.load_debug_rollout_data.format(rollout_id=rollout_id), "rb"), + self.args.load_debug_rollout_data.format(rollout_id=rollout_id), weights_only=False, )["samples"] data = [Sample.from_dict(sample) for sample in data] @@ -152,22 +242,66 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) + if self.use_experimental_refactor: + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + else: + data = call_rollout_fn( + self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False + ) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists while isinstance(data[0], list): - data = sum(data, []) - - if self.args.disable_rollout_trim_samples: - logger.info(f"Collectd {len(data)} samples from rollout to train") - elif len(data) % self.args.global_batch_size != 0: - trim_len = (len(data) // self.args.global_batch_size) * self.args.global_batch_size - origin_data_length = len(data) - data = data[:trim_len] - logger.info(f"trim number of samples from {origin_data_length} to {trim_len}") + data = list(itertools.chain.from_iterable(data)) + + if not self.args.disable_rollout_trim_samples: + global_batch_size = self.args.global_batch_size + if self.args.use_dynamic_global_batch_size: + logger.info(f"Collected {len(data)} samples from rollout to train with dynamic global batch size") + # TODO: this is a 
temporary solution, we should directly save dynamic_global_batch_size to rollout data + self._dynamic_global_batch_size = self._compute_dynamic_global_batch_size(len(data)) + global_batch_size = self._dynamic_global_batch_size + + if len(data) % global_batch_size != 0: + trim_len = (len(data) // global_batch_size) * global_batch_size + if trim_len == 0: + raise ValueError(f"Not enough samples {len(data)} for global_batch_size {global_batch_size}") + origin_data_length = len(data) + data = data[:trim_len] + logger.info(f"trim number of samples from {origin_data_length} to {trim_len}") + logger.info(f"Final collected {len(data)} samples from rollout to train") + return data, metrics + def _compute_dynamic_global_batch_size(self, num_samples: int) -> int: + """Calculate dynamic global_batch_size to ensure only one training step. + + Strategy: global_batch_size = num_samples rounded down to a multiple of dp_size + This ensures num_steps_per_rollout = num_samples // global_batch_size = 1 + """ + dp_size = self.train_parallel_config["dp_size"] + original_gbs = self.args.global_batch_size + + # Round down to a multiple of dp_size to ensure only one training step + dynamic_gbs = (num_samples // dp_size) * dp_size + + if dynamic_gbs == 0: + # Too few samples, use at least dp_size + dynamic_gbs = dp_size + logger.warning(f"num_samples={num_samples} < dp_size={dp_size}, using dp_size as global_batch_size") + + # Calculate how many samples will be discarded + wasted = num_samples - dynamic_gbs + + if dynamic_gbs != original_gbs or wasted > 0: + logger.info( + f"Dynamic global_batch_size: {original_gbs} -> {dynamic_gbs} " + f"(num_samples={num_samples}, dp_size={dp_size}, " + f"num_steps=1, wasted={wasted})" + ) + + return dynamic_gbs + def _save_debug_rollout_data(self, data, rollout_id, evaluation: bool): # TODO to be refactored (originally Buffer._set_data) if (path_template := self.args.save_debug_rollout_data) is not None: @@ -218,6 +352,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl """ Convert inference generated samples to training data. 
""" + if self.custom_convert_samples_to_train_data_func is not None: + return self.custom_convert_samples_to_train_data_func(self.args, samples) + raw_rewards, rewards = self._post_process_rewards(samples) assert len(raw_rewards) == len(samples) @@ -268,8 +405,8 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if samples[0].train_metadata is not None: train_data["metadata"] = [sample.train_metadata for sample in samples] - if samples[0].multimodal_inputs is not None: - train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples] + if samples[0].multimodal_train_inputs is not None: + train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] if "teacher_log_probs" in samples[0].__dict__: train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples] @@ -302,7 +439,7 @@ def _split_train_data_by_dp(self, data, dp_size): rollout_data["partition"] = partition for key in [ "tokens", - "multimodal_inputs", + "multimodal_train_inputs", "response_lengths", "rewards", "truncated", @@ -326,13 +463,16 @@ def _split_train_data_by_dp(self, data, dp_size): if key not in data: continue rollout_data[key] = data[key] + # Pass dynamic global_batch_size to training side + if hasattr(self, "_dynamic_global_batch_size"): + rollout_data["dynamic_global_batch_size"] = self._dynamic_global_batch_size rollout_data_refs.append(Box(ray.put(rollout_data))) return rollout_data_refs def init_rollout_engines(args, pg, all_rollout_engines): if args.debug_train_only: - return 0, None + return 0 num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) num_engines = args.rollout_num_gpus // num_gpu_per_engine @@ -343,7 +483,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_engines > prefill_num_servers ), f"num_engines {num_engines} should be larger than prefill_num_servers {prefill_num_servers}" - pg, reordered_bundle_indices = pg + pg, reordered_bundle_indices, reordered_gpu_ids = pg RolloutRayActor = ray.remote(SGLangEngine) @@ -355,6 +495,9 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_gpus = 0.2 num_cpus = num_gpus + # Get the base GPU ID from placement group + base_gpu_id = int(reordered_gpu_ids[i * num_gpu_per_engine]) + scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_capture_child_tasks=True, @@ -369,27 +512,9 @@ def init_rollout_engines(args, pg, all_rollout_engines): "SGLANG_MEMORY_SAVER_CUDA_GRAPH": "true", "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_FALLBACK_VARIANT": "true", "SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION": "false", + "SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE": "false", } - # TODO: currently the amem position is hardcoded, change to a better way later. - # note that amem does not work with update weights from distributed. 
- if ( - args.offload_rollout - and args.actor_num_nodes * args.actor_num_gpus_per_node >= args.rollout_num_gpus - and len(glob("/usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libamem_nccl.so*")) > 0 - ): - logger.info("Enable AMEM for rollout engine.") - ld_library_path = ( - os.environ.get("LD_LIBRARY_PATH", "") + ":/usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib" - ) - env_vars |= { - "LD_LIBRARY_PATH": ld_library_path, - "NCCL_CUMEM_ENABLE": "1", - "AMEM_ENABLE": "1", - "AMEM_GROUPID": "0", - "GMM_LOG": "2", - } - worker_type = "regular" if args.prefill_num_servers is not None: if i < prefill_num_servers: @@ -404,7 +529,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): runtime_env={ "env_vars": env_vars, }, - ).remote(args, rank=i, worker_type=worker_type) + ).remote(args, rank=i, worker_type=worker_type, base_gpu_id=base_gpu_id) rollout_engines.append((i, rollout_engine)) all_rollout_engines[i] = rollout_engine @@ -412,7 +537,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_new_engines = len(rollout_engines) if num_new_engines == 0: - return num_new_engines, None + return num_new_engines if args.rollout_external: addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines) @@ -432,10 +557,11 @@ def init_rollout_engines(args, pg, all_rollout_engines): def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines): addr_and_ports = [] for rank, _ in rollout_engines: - [host, port] = args.rollout_external_engine_addrs[rank].split(":") + addr = args.rollout_external_engine_addrs[rank] + [host, port] = addr.split(":") addr_and_ports.append( dict( - dist_init_addr=None, + dist_init_addr=addr, nccl_port=None, host=host, port=int(port), @@ -456,6 +582,12 @@ def _allocate_rollout_engine_addr_and_ports_normal(*, args, num_engines, rollout ) addr_and_ports = [{} for _ in range(num_engines)] + # Calculate prefill limit to identify prefill engines + prefill_limit = 0 + if args.prefill_num_servers is not None: + num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) + prefill_limit = args.prefill_num_servers * args.rollout_num_gpus_per_engine // num_gpu_per_engine + visited_nodes = set() for rank, engine in rollout_engines: if rank // num_engines_per_node in visited_nodes: @@ -490,20 +622,24 @@ def addr(): get_addr, get_port = get_addr_and_ports(engine) for i in range(num_engines_on_this_node): - addr_and_ports[rank + i]["host"] = get_addr() - addr_and_ports[rank + i]["port"] = get_port() - addr_and_ports[rank + i]["nccl_port"] = get_port() + current_rank = rank + i + addr_and_ports[current_rank]["host"] = get_addr() + addr_and_ports[current_rank]["port"] = get_port() + addr_and_ports[current_rank]["nccl_port"] = get_port() + + if args.prefill_num_servers is not None and current_rank < prefill_limit: + addr_and_ports[current_rank]["disaggregation_bootstrap_port"] = get_port() if args.rollout_num_gpus_per_engine > args.num_gpus_per_node: num_node_per_engine = args.rollout_num_gpus_per_engine // args.num_gpus_per_node if rank % num_node_per_engine == 0: # this is the first node in the engine, we need to allocate the dist_init_addr port - dist_init_addr = f"{get_addr()}:{get_port(6 + args.sglang_dp_size)}" + dist_init_addr = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}" for i in range(num_node_per_engine): addr_and_ports[rank + i]["dist_init_addr"] = dist_init_addr else: for i in range(num_engines_on_this_node): - addr_and_ports[rank + i]["dist_init_addr"] = 
f"{get_addr()}:{get_port(6 + args.sglang_dp_size)}" + addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}" for i, _ in rollout_engines: for key in ["port", "nccl_port", "dist_init_addr"]: @@ -513,7 +649,7 @@ def addr(): return addr_and_ports -def _start_router(args, prefill_and_decode_urls=None): +def _start_router(args): """start sgl router and miles router""" if args.sglang_router_ip is not None: return @@ -538,13 +674,11 @@ def _start_router(args, prefill_and_decode_urls=None): router_args.port = args.sglang_router_port router_args.prometheus_port = find_available_port(random.randint(4000, 5000)) router_args.log_level = "warn" + router_args.request_timeout_secs = args.sglang_router_request_timeout_secs if args.prefill_num_servers is not None: router_args.pd_disaggregation = True - if hasattr(router_args, "request_timeout_secs"): - router_args.request_timeout_secs = args.sglang_router_request_timeout_secs - logger.info(f"Launch router with args: {router_args}") process = multiprocessing.Process( @@ -560,6 +694,11 @@ def _start_router(args, prefill_and_decode_urls=None): def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] | None = None): + if args.custom_eval_rollout_log_function_path is not None: + custom_log_func = load_function(args.custom_eval_rollout_log_function_path) + if custom_log_func(rollout_id, args, data, extra_metrics): + return + log_dict = extra_metrics or {} for key in data.keys(): rewards = data[key]["rewards"] @@ -588,16 +727,17 @@ def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time): + if args.custom_rollout_log_function_path is not None: + custom_log_func = load_function(args.custom_rollout_log_function_path) + if custom_log_func(rollout_id, args, samples, rollout_extra_metrics, rollout_time): + return + if args.load_debug_rollout_data: return log_dict = {**(rollout_extra_metrics or {})} - response_lengths = [sample.effective_response_length for sample in samples] - log_dict["perf/rollout_time"] = rollout_time - if args.rollout_num_gpus: - log_dict["perf/tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus - log_dict["perf/longest_sample_tokens_per_sec"] = max(response_lengths) / rollout_time log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/") + log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/") logger.info(f"perf {rollout_id}: {log_dict}") step = compute_rollout_step(args, rollout_id) log_dict["rollout/step"] = step @@ -610,13 +750,45 @@ def compute_metrics_from_samples(args, samples): log_dict = {} log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/") log_dict |= _compute_zero_std_metrics(args, samples) - log_dict |= _compute_spec_metrics(args, samples) log_dict |= _compute_reward_cat_metrics(args, samples) log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item() return log_dict +def compute_perf_metrics_from_samples(args, samples, rollout_time): + non_generation_time = [sample.non_generation_time for sample in samples] + + log_dict = {} + log_dict["rollout_time"] = rollout_time + if max(non_generation_time) > 0: + log_dict |= dict_add_prefix(compute_statistics(non_generation_time), 
"non_generation_time/") + + def token_perf(response_lengths, non_generation_time, key=""): + max_response_length = max(response_lengths) + if args.rollout_num_gpus: + log_dict[f"{key}tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus + log_dict[f"longest_{key}sample_tokens_per_sec"] = max_response_length / rollout_time + + if max(non_generation_time) == 0: + return + + non_generation_time = [ + t for t, length in zip(non_generation_time, response_lengths, strict=True) if length == max_response_length + ] + mean_non_generation_time = sum(non_generation_time) / len(non_generation_time) + + log_dict[f"longest_{key}sample_non_generation_time"] = mean_non_generation_time + log_dict[f"longest_{key}sample_tokens_per_sec_without_non_generation"] = max_response_length / ( + rollout_time - mean_non_generation_time + ) + + token_perf([sample.response_length for sample in samples], non_generation_time, key="") + token_perf([sample.effective_response_length for sample in samples], non_generation_time, key="effective_") + + return log_dict + + def _compute_zero_std_metrics(args, all_samples: list[Sample]): # only compute in GRPO-like algorithms where one prompt has multiple responses if args.advantage_estimator == "ppo": @@ -639,12 +811,19 @@ def _compute_spec_metrics(args, all_samples: list[Sample]): return {} num_samples = len(all_samples) metrics = {} - metrics["rollout/spec_accept_rate"] = ( - sum(sample.spec_info.spec_accept_rate for sample in all_samples) / num_samples - ) - metrics["rollout/spec_accept_length"] = ( - sum(sample.spec_info.spec_accept_length for sample in all_samples) / num_samples - ) + metrics["spec_accept_rate"] = sum(sample.spec_info.spec_accept_rate for sample in all_samples) / num_samples + metrics["spec_accept_length"] = sum(sample.spec_info.spec_accept_length for sample in all_samples) / num_samples + return metrics + + +def _compute_prefix_cache_metrics(args, all_samples: list[Sample]): + num_samples = len(all_samples) + metrics = {} + total_cached_tokens = sum(sample.prefix_cache_info.cached_tokens for sample in all_samples) + total_prompt_tokens = sum(sample.prefix_cache_info.total_prompt_tokens for sample in all_samples) + + metrics["prefix_cache_hit_rate"] = total_cached_tokens / total_prompt_tokens if total_prompt_tokens > 0 else 0.0 + metrics["avg_cached_tokens_per_sample"] = total_cached_tokens / num_samples return metrics diff --git a/miles/ray/rollout_data_source.py b/miles/ray/rollout_data_source.py deleted file mode 100644 index c9df08f4f..000000000 --- a/miles/ray/rollout_data_source.py +++ /dev/null @@ -1,186 +0,0 @@ -import copy -import logging -import os -from pathlib import Path - -import torch -from transformers import AutoTokenizer - -from miles.utils.data import Dataset -from miles.utils.misc import load_function -from miles.utils.types import Sample - -logger = logging.getLogger(__name__) - - -# TODO may further refactor data-loading part later -class RolloutDataSource: - def __init__(self, args): - self.args = args - - self.epoch_id = 0 - self.sample_group_index = 0 - self.sample_index = 0 - self.sample_offset = 0 - # TODO remove this - self.metadata = {} - - if args.rollout_global_dataset: - tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) - - # TODO move (during the refactor) - if (d := args.dump_details) is not None: - tokenizer.save_pretrained(Path(d) / "tokenizer") - - self.dataset = Dataset( - args.prompt_data, - tokenizer=tokenizer, - 
max_length=args.rollout_max_prompt_len, - prompt_key=args.input_key, - label_key=args.label_key, - metadata_key=args.metadata_key, - tool_key=args.tool_key, - apply_chat_template=args.apply_chat_template, - apply_chat_template_kwargs=args.apply_chat_template_kwargs, - seed=args.rollout_seed, - ) - if self.args.rollout_shuffle: - self.dataset.shuffle(self.epoch_id) - else: - self.dataset = None - - def get_samples(self, num_samples): - # TODO further improve code - if self.dataset is not None: - if self.sample_offset + num_samples <= len(self.dataset): - prompt_samples = self.dataset.samples[self.sample_offset : self.sample_offset + num_samples] - self.sample_offset += num_samples - else: - prompt_samples = self.dataset.samples[self.sample_offset :] - num_samples -= len(prompt_samples) - self.epoch_id += 1 - if self.args.rollout_shuffle: - self.dataset.shuffle(self.epoch_id) - prompt_samples += self.dataset.samples[:num_samples] - self.sample_offset = num_samples - else: - prompt_samples = [Sample() for _ in range(num_samples)] - - samples = [] - for prompt_sample in prompt_samples: - group = [] - for _ in range(self.args.n_samples_per_prompt): - sample = copy.deepcopy(prompt_sample) - sample.group_index = self.sample_group_index - sample.index = self.sample_index - self.sample_index += 1 - group.append(sample) - self.sample_group_index += 1 - samples.append(group) - return samples - - def add_samples(self, samples: list[list[Sample]]): - raise RuntimeError(f"Cannot add samples to {self.__class__.__name__}. This is a read-only data source.") - - def save(self, rollout_id): - if not self.args.rollout_global_dataset: - return - - state_dict = { - "sample_offset": self.sample_offset, - "epoch_id": self.epoch_id, - "sample_group_index": self.sample_group_index, - "sample_index": self.sample_index, - "metadata": self.metadata, - } - path = os.path.join(self.args.save, f"rollout/global_dataset_state_dict_{rollout_id}.pt") - os.makedirs(os.path.dirname(path), exist_ok=True) - torch.save(state_dict, path) - - def load(self, rollout_id=None): - if not self.args.rollout_global_dataset: - return - - if self.args.load is None: - return - - path = os.path.join(self.args.load, f"rollout/global_dataset_state_dict_{rollout_id}.pt") - if not os.path.exists(path): - logger.info(f"Checkpoint {path} does not exist.") - return - - logger.info(f"load metadata from {path}") - logger.info(f"load metadata: {self.metadata}") - state_dict = torch.load(path) - self.sample_offset = state_dict.get("sample_offset", 0) - self.epoch_id = state_dict.get("epoch_id", 0) - self.sample_group_index = state_dict.get("sample_group_index", 0) - self.sample_index = state_dict.get("sample_index", 0) - self.metadata = state_dict.get("metadata", {}) - - if self.args.rollout_global_dataset and self.args.rollout_shuffle: - self.dataset.shuffle(self.epoch_id) - - -class RolloutDataSourceWithBuffer(RolloutDataSource): - def __init__(self, args): - super().__init__(args) - self.buffer = [] - if self.args.buffer_filter_path is None: - self.buffer_filter = pop_first - else: - self.buffer_filter = load_function(self.args.buffer_filter_path) - - def get_samples(self, num_samples: int) -> list[list[Sample]]: - """ - Return num_samples samples - """ - - samples = self._get_samples_from_buffer(num_samples) - num_samples -= len(samples) - - if num_samples == 0: - return samples - - samples += super().get_samples(num_samples=num_samples) - return samples - - def _get_samples_from_buffer(self, num_samples: int) -> list[list[Sample]]: - if 
len(self.buffer) == 0 or num_samples == 0: - return [] - - samples = self.buffer_filter(self.args, None, self.buffer, num_samples) - return samples - - def add_samples(self, samples: list[list[Sample]]): - """ - Add a sample group to buffer. - """ - if not samples: - return - assert isinstance(samples, list), f"samples must be a list, got {type(samples)}" - assert isinstance(samples[0], list), f"the elements of samples must be list, got {type(samples[0])}" - for i in range(0, len(samples)): - assert ( - len(samples[i]) == self.args.n_samples_per_prompt - ), f"the length of the elements of samples must be equal to n_samples_per_prompt, got {len(samples[i])} != {self.args.n_samples_per_prompt}" - group = samples[i] # type: ignore - self.buffer.append(group) - - # TODO remove - def update_metadata(self, metadata: dict): - self.metadata.update(metadata) - - # TODO remove - def get_metadata(self): - return self.metadata - - def get_buffer_length(self): - return len(self.buffer) - - -def pop_first(args, rollout_id, buffer: list[list[Sample]], num_samples: int) -> list[list[Sample]]: - num_to_pop = min(len(buffer), num_samples) - samples = buffer[:num_to_pop] - del buffer[:num_to_pop] - return samples diff --git a/miles/ray/train_actor.py b/miles/ray/train_actor.py index 3d3923e9c..81e3fae00 100644 --- a/miles/ray/train_actor.py +++ b/miles/ray/train_actor.py @@ -113,7 +113,7 @@ def train(self, rollout_id, rollout_data_ref): raise NotImplementedError @abc.abstractmethod - def save_model(self, iteration): + def save_model(self, rollout_id, force_sync=False): raise NotImplementedError @abc.abstractmethod diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index faa85c726..c2644e87f 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,22 +1,86 @@ +from __future__ import annotations + +from argparse import Namespace from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any +from miles.rollout.data_source import DataSource from miles.utils.types import Sample +if TYPE_CHECKING: + from miles.rollout.inference_rollout.inference_rollout_common import GenerateState + + +@dataclass(frozen=True) +class RolloutFnConstructorInput: + args: Namespace + # TODO may refactor DataSource API + data_source: DataSource + + +@dataclass(frozen=True) +class RolloutFnBaseInput: + rollout_id: int + + @property + def evaluation(self): + raise NotImplementedError + + +# subclassing for different data in the future +@dataclass(frozen=True) +class RolloutFnTrainInput(RolloutFnBaseInput): + @property + def evaluation(self): + return False + +@dataclass(frozen=True) +class RolloutFnEvalInput(RolloutFnBaseInput): + @property + def evaluation(self): + return True + + +# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None +# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] metrics: dict[str, Any] = None +RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput +RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput + + +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. 
+ samples: Sample | list[Sample] + + def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): + """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) # compatibility for legacy version diff --git a/miles/rollout/filter_hub/base_types.py b/miles/rollout/filter_hub/base_types.py index ba1a4441c..2937273bd 100644 --- a/miles/rollout/filter_hub/base_types.py +++ b/miles/rollout/filter_hub/base_types.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass @@ -5,3 +6,32 @@ class DynamicFilterOutput: keep: bool reason: str | None = None + + +def call_dynamic_filter(fn, *args, **kwargs): + if fn is None: + return DynamicFilterOutput(keep=True) + + output = fn(*args, **kwargs) + + # compatibility for legacy version + if not isinstance(output, DynamicFilterOutput): + output = DynamicFilterOutput(keep=output) + + return output + + +class MetricGatherer: + def __init__(self): + self._dynamic_filter_drop_reason_count = defaultdict(lambda: 0) + + def on_dynamic_filter_drop(self, reason: str | None): + if not reason: + return + self._dynamic_filter_drop_reason_count[reason] += 1 + + def collect(self): + return { + f"rollout/dynamic_filter/drop_{reason}": count + for reason, count in self._dynamic_filter_drop_reason_count.items() + } diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py new file mode 100644 index 000000000..05223a654 --- /dev/null +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -0,0 +1,85 @@ +""" +Simple agentic demo with tool calling. +""" + +import argparse +from copy import deepcopy +from typing import Any + +from openai import AsyncOpenAI + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples +from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + + await _run_blackbox_tool_call_agent( + base_url=tracer.base_url, + prompt=input.sample.prompt, + max_turns=input.args.generate_max_turns, + tool_specs_path=input.args.generate_tool_specs_path, + execute_tool_function_path=input.args.generate_execute_tool_function_path, + ) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments + + +async def _run_blackbox_tool_call_agent( + base_url: str, + prompt: list[dict[str, Any]], + max_turns: int, + tool_specs_path: str, + execute_tool_function_path: str, +): + """ + Imagine this is a black-box agent, e.g. 
SWE-agent, which does arbitrarily complex work, + only understands OpenAI compatible API, and never understands Miles or the Sample data structure. + """ + + # ----------------------- Setup ------------------------- + + client = AsyncOpenAI(base_url=base_url, api_key="empty") + execute_tool_function = load_function(execute_tool_function_path) + tool_specs = load_function(tool_specs_path) + + # ----------------------- Initial prompts ------------------------- + + messages = deepcopy(prompt) + + for _turn in range(max_turns): + # ----------------------- Call inference endpoint ------------------------- + + response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) + + choice = response.choices[0] + messages.append(choice.message.model_dump()) + + if choice.finish_reason in ("stop", "length"): + break + + # ----------------------- Execute tools ------------------------- + + if x := choice.message.tool_calls: + messages += await execute_tool_calls(x, execute_tool_function) diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py new file mode 100644 index 000000000..97814ecb3 --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn.py @@ -0,0 +1,88 @@ +""" +Simple multi-turn generation with tool calling. +""" + +import argparse +from copy import deepcopy + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_utils.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) +from miles.utils.http_utils import post +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + + args = input.args + sample = deepcopy(input.sample) + tokenizer = input.state.tokenizer + assert not args.partial_rollout, "Partial rollout is not supported" + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + + multi_samples = [] + + # ----------------------- Initial prompts ------------------------- + + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) + + sample.tokens = prompt_tokens_ids.copy() + + for _turn in range(args.generate_max_turns): + # ----------------------- Call inference endpoint ------------------------- + + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + if args.generate_multi_samples and multi_samples: + multi_samples[-1].status = halt_status + break + + if args.generate_multi_samples: + sample = deepcopy(input.sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) + + if args.generate_multi_samples: + multi_samples.append(deepcopy(sample)) + + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): + break + + # ----------------------- Execute tools ------------------------- + + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) 
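+        # parse_non_stream returns (normal_text, tool_calls); an empty tool_calls list
+        # means the model produced a final answer instead of requesting a tool, so the
+        # multi-turn loop ends here rather than executing tools.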
+ if len(tool_calls) == 0: + break + + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) + + return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py new file mode 100644 index 000000000..5c0a15b5b --- /dev/null +++ b/miles/rollout/generate_hub/single_turn.py @@ -0,0 +1,46 @@ +""" +Simple single-turn generation. +""" + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + prompt_ids = compute_prompt_ids_from_sample(input.state, sample) + + # Handle Partial Rollout resuming + if len(sample.response) > 0: + input_ids = sample.tokens + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert sampling_params["max_new_tokens"] >= 0 + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + else: + input_ids = prompt_ids + + payload, halt_status = compute_request_payload( + args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs + ) + if payload is None: + sample.status = halt_status + return GenerateFnOutput(samples=sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output) + + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py new file mode 100644 index 000000000..a91d71f1d --- /dev/null +++ b/miles/rollout/generate_utils/generate_endpoint_utils.py @@ -0,0 +1,112 @@ +""" +Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. 
+""" + +from copy import deepcopy +from typing import Any + +import numpy as np +import pybase64 + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +# Make this an isolated function because users may want to compute their own +def compute_prompt_ids_from_sample(state, sample, tools=None): + prompt = sample.prompt + + if state.processor: + processor_output = state.processor(text=prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + + # TODO shall we move it to other places? then can make this function immutable + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + + return prompt_ids + else: + if not isinstance(prompt, str): + prompt = state.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True, tools=tools + ) + + return state.tokenizer.encode(prompt, add_special_tokens=False) + + +def compute_request_payload( + args, + input_ids: list[int], + sampling_params: dict, + multimodal_inputs: dict | None = None, +) -> tuple[dict[str, Any] | None, Sample.Status | None]: + sampling_params = deepcopy(sampling_params) + max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) + if x := args.rollout_max_context_len: + max_new_tokens = min(max_new_tokens, x - len(input_ids)) + if max_new_tokens <= 0: + return None, Sample.Status.TRUNCATED + + payload = { + "input_ids": input_ids, + "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, + "return_logprob": True, + "return_routed_experts": args.use_rollout_routing_replay, + } + if image_data := (multimodal_inputs or {}).get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + return payload, None + + +async def update_sample_from_response( + args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False +): + # Initialize sample.tokens for the first turn + if (len(sample.response) == 0) and not sample.tokens: + sample.tokens = payload["input_ids"] + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + # TODO may rename to match + await postprocess_sample_with_radix_tree(args, sample, output) + + assert not update_loss_mask, "This code branch has not implemented update_loss_mask" + else: + if x := output["meta_info"].get("output_token_logprobs"): + new_response_tokens = [item[1] for item in x] + new_response_log_probs = [item[0] for item in x] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if update_loss_mask: + if sample.loss_mask is None: + sample.loss_mask = [] + sample.loss_mask += [1] * len(new_response_tokens) + + # TODO handle multi-turn cases (may need concat instead of assignment) + sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) + + # TODO may unify (currently there are both methods inside Sample and separate functions) + sample.update_from_meta_info(args, output["meta_info"]) + + 
+def _get_rollout_routed_experts_from_response(args, sample, output): + info = output["meta_info"].get("routed_experts") + if info is None: + return None + + x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) + x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) + return x diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py new file mode 100644 index 000000000..73ba8198b --- /dev/null +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -0,0 +1,67 @@ +""" +Utilities for the OpenAI endpoint +""" + +import logging +from argparse import Namespace +from copy import deepcopy + +from miles.router.sessions import GetSessionResponse, SessionRecord +from miles.utils.http_utils import post +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class OpenAIEndpointTracer: + def __init__(self, router_url: str, session_id: str): + self.router_url = router_url + self.session_id = session_id + self.base_url = f"{router_url}/sessions/{session_id}/v1" + + @staticmethod + async def create(args: Namespace): + router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + session_id = (await post(f"{router_url}/sessions", {}))["session_id"] + return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) + + async def collect_records(self) -> list[SessionRecord]: + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + response = GetSessionResponse.model_validate(response) + records = response.records + + try: + await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + except Exception as e: + logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") + + return records + + +def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: + return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] + + +def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: + # TODO may refine after @guapisolo's implementation + choice = record.response["choices"][0] + output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] + output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] + + sample = deepcopy(input_sample) + sample.tokens = record.request["input_ids"] + output_token_ids + sample.rollout_log_probs = output_log_probs + sample.response = tokenizer.decode(output_token_ids) + sample.response_length = len(output_token_ids) + sample.loss_mask = [1] * len(output_token_ids) + + # TODO unify with Sample.update_from_meta_info + match choice["finish_reason"]: + case "stop" | "tool_calls": + sample.status = Sample.Status.COMPLETED + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + + return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py new file mode 100644 index 000000000..6a4e645be --- /dev/null +++ b/miles/rollout/generate_utils/sample_utils.py @@ -0,0 +1,115 @@ +from copy import deepcopy +from dataclasses import fields + +from miles.utils.types import Sample + + +def merge_samples(samples: list[Sample], tokenizer) -> Sample: + acc = samples[0] + for sample in samples[1:]: + acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) + 
return acc + + +def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: + """Merge two samples generated from sibling inference engine calls.""" + a, b = deepcopy(a), deepcopy(b) + + def _merge_equal_value(field): + x = getattr(a, field) + y = getattr(b, field) + assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" + return x + + def _fill_defaults(sample: Sample): + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [0.0] * sample.response_length + + _fill_defaults(a) + _fill_defaults(b) + + obs_len = len(b.tokens) - len(a.tokens) - b.response_length + obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] + # TODO: is this acceptable? + obs_text = tokenizer.decode(obs_tokens) + + try: + a.validate() + b.validate() + assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" + assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" + assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + if a.rollout_routed_experts is not None: + assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] + assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" + + return _create_with_all_fields( + Sample, + group_index=_merge_equal_value("group_index"), + index=_merge_equal_value("index"), + prompt=b.prompt, + tokens=b.tokens, + multimodal_inputs=_merge_equal_value("multimodal_inputs"), + multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), + response=a.response + obs_text + b.response, + response_length=a.response_length + obs_len + b.response_length, + label=_merge_equal_value("label"), + reward=_merge_equal_value("reward"), + loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, + weight_versions=a.weight_versions + b.weight_versions, + rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, + rollout_routed_experts=b.rollout_routed_experts, + remove_sample=_merge_equal_value("remove_sample"), + status=b.status, + metadata=_merge_equal_value("metadata"), + train_metadata=_merge_equal_value("train_metadata"), + non_generation_time=_merge_equal_value("non_generation_time"), + spec_info=_merge_spec_info(a.spec_info, b.spec_info), + prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), + ) + except AssertionError as e: + e.add_note(f"{a=} {b=}") + raise + + +def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.SpecInfo, + spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), + spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), + spec_verify_ct=_merge_plus_value("spec_verify_ct"), + completion_token_num=_merge_plus_value("completion_token_num"), + ) + + +def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.PrefixCacheInfo, + cached_tokens=_merge_plus_value("cached_tokens"), + total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), + ) + + +def _create_with_all_fields(cls, **kwargs): + expected = {f.name for f in fields(cls)} + actual = set(kwargs.keys()) + assert ( + expected == actual + ), f"{cls.__name__} field mismatch. 
Missing: {expected - actual}, Extra: {actual - expected}" + return cls(**kwargs) + + +def _startswith(*, short, long) -> bool: + if isinstance(short, str) and isinstance(long, str): + return long.startswith(short) + if isinstance(short, list) and isinstance(long, list): + return (len(long) >= len(short)) and (long[: len(short)] == short) + raise NotImplementedError diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py new file mode 100644 index 000000000..85ea87aea --- /dev/null +++ b/miles/rollout/generate_utils/tool_call_utils.py @@ -0,0 +1,115 @@ +""" +Utils to handle tool calls. +""" + +import json +import uuid +from collections.abc import Callable +from typing import Any + +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.types import Sample + +_DUMMY_USER = {"role": "user", "content": "dummy"} + + +def create_tool_call_parser(tool_specs, tool_call_parser): + return FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tool_specs), + tool_call_parser=tool_call_parser, + ) + + +async def execute_tool_calls( + tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], + execute_one: Callable, +) -> list[dict[str, Any]]: + tool_messages = [] + for call in tool_calls: + tool_messages.append(await _execute_tool_call(call, execute_one)) + return tool_messages + + +async def _execute_tool_call( + call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable +) -> dict[str, Any]: + if isinstance(call, ChatCompletionMessageToolCall): + name = call.function.name + params = json.loads(call.function.arguments) if call.function.arguments else {} + tool_call_id = call.id + elif isinstance(call, ToolCallItem): + name = call.name + params = json.loads(call.parameters) if call.parameters else {} + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + else: + raise TypeError(f"Unsupported tool call type: {type(call)}") + + result = await execute_one(name, params) + assert isinstance(result, str) + + return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} + + +def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + sample.response += tokenizer.decode(next_obs_tokens_ids) + sample.response_length += len(next_obs_tokens_ids) + sample.tokens += next_obs_tokens_ids + sample.loss_mask += [0] * len(next_obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) + + +# TODO: very naive implementation, need the to-be-implemented e2e test to validate. 
+def tokenize_tool_responses( + tool_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + return _tokenize_postfix_messages(tool_messages, tokenizer) + + +def _tokenize_postfix_messages( + postfix_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + dummy_assistant = _build_dummy_assistant(postfix_messages) + base_messages = [_DUMMY_USER, dummy_assistant] + + messages_without = base_messages + messages_with = base_messages + postfix_messages + + tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) + tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) + + assert tokens_with[: len(tokens_without)] == tokens_without, ( + f"Fail to tokenize_tool_responses caused by token prefix mismatch. " + f"This can happen for thinking model or models with special chat template, " + f"and this simple example does not support it yet, " + f"since this means we cannot have a append-only token id list. " + f"{tokens_with=} {tokens_without=} " + f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " + ) + return tokens_with[len(tokens_without) :] + + +def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: + return { + "role": "assistant", + "content": "", + "reasoning_content": " ", + "tool_calls": [ + { + "id": resp.get("tool_call_id", f"call0000{i}"), + "type": "function", + "function": { + "name": resp.get("name", "dummy_func"), + "arguments": {}, + }, + } + for i, resp in enumerate(tool_responses) + ], + } diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py new file mode 100644 index 000000000..33ccf17bf --- /dev/null +++ b/miles/rollout/inference_rollout/__init__.py @@ -0,0 +1,2 @@ +# This is a refactor of the portions above generate-function in sglang_rollout.py, +# and is give a different name to ensure both code exist at the same time. 
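# A standalone sketch (illustration only, not part of this patch) of the "tokenize with
# and without the postfix, keep the suffix" trick that _tokenize_postfix_messages above
# relies on to keep sample.tokens append-only. The checkpoint name is purely
# illustrative; as the assert in the patch notes, chat templates that rewrite earlier
# turns (e.g. some thinking models) can break the prefix property.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

base = [
    {"role": "user", "content": "dummy"},
    {"role": "assistant", "content": "ok"},
]
postfix = [{"role": "tool", "tool_call_id": "call_0", "name": "calc", "content": "42"}]

tokens_without = tokenizer.apply_chat_template(base, tokenize=True, add_generation_prompt=False)
tokens_with = tokenizer.apply_chat_template(base + postfix, tokenize=True, add_generation_prompt=True)

# If the template is append-only, the shorter token list is a strict prefix of the longer
# one, and the delta is exactly what gets appended to sample.tokens with loss_mask = 0.
assert tokens_with[: len(tokens_without)] == tokens_without
tool_response_tokens = tokens_with[len(tokens_without) :]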
diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py new file mode 100644 index 000000000..7711e0dd3 --- /dev/null +++ b/miles/rollout/inference_rollout/compatibility.py @@ -0,0 +1,84 @@ +import inspect +from collections.abc import Callable + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainOutput, +) +from miles.utils.async_utils import run +from miles.utils.misc import load_function + + +class LegacyRolloutFnAdapter: + def __init__(self, input: RolloutFnConstructorInput, fn: Callable): + self.args = input.args + self.data_source = input.data_source + self.fn = fn + + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) + + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) + + return output + + +def load_rollout_function(input: RolloutFnConstructorInput, path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn(input) + else: + return LegacyRolloutFnAdapter(input, fn) + + +def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: + output = fn(input) + + if inspect.iscoroutine(output): + output = run(output) + + return output + + +class LegacyGenerateFnAdapter: + def __init__(self, fn: Callable): + self.fn = fn + self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters + + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + if self._has_evaluation_param: + output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) + else: + output = await self.fn(input.args, input.sample, input.sampling_params) + + if not isinstance(output, GenerateFnOutput): + output = GenerateFnOutput(samples=output) + + return output + + +def load_generate_function(path: str): + fn = load_function(path) + if fn is None: + return None + + if inspect.isclass(fn): + return fn() + elif _is_legacy_generate_fn(fn): + return LegacyGenerateFnAdapter(fn) + else: + return fn + + +def _is_legacy_generate_fn(fn: Callable) -> bool: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py new file mode 100644 index 000000000..8518c6e02 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -0,0 +1,192 @@ +import asyncio +import logging +from argparse import Namespace +from copy import deepcopy +from typing import Any + +from miles.rollout.base_types import ( + GenerateFnInput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.generate_hub.single_turn import generate +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class GenerateState: + def __init__(self, args: 
Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.generate_fn_semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = compute_sampling_params( + args, + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + ) + + self.generate_function = load_generate_function(args.custom_generate_function_path) or generate + + self.reset() + + def reset(self) -> None: + self.aborted = False + + +async def generate_and_rm( + state: GenerateState, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + args = state.args + + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + # generate + async with state.generate_fn_semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=deepcopy(sampling_params), + evaluation=evaluation, + ) + ) + sample = output.samples + + # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # TODO: unify the two branches into one if we decide to use list as output type + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. + samples_need_reward = [sample for sample in samples if sample.reward is None] + await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) + return samples + else: + if sample.status == Sample.Status.ABORTED: + return sample + # for multi-turn environment, a reward could be assigned to the agent. 
+ if sample.reward is None: + sample.reward = await async_rm(args, sample) + + return sample + + +async def generate_and_rm_group( + state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False +) -> list[Sample]: + args = state.args + + if state.aborted: + return group + + tasks = [] + for idx, sample in enumerate(group): + current_sampling_params = sampling_params.copy() + if getattr(args, "sglang_enable_deterministic_inference", False): + current_sampling_params["sampling_seed"] = args.rollout_seed + idx + tasks.append( + asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) + ) + + group = await asyncio.gather(*tasks) + if state.aborted: + return group + + if args.group_rm: + await batched_async_rm(args, group, inplace_set_reward_field=True) + + return group + + +def compute_sampling_params( + args, + *, + # after unifying configuration, this can be further refactored + temperature, + top_p, + top_k, + max_new_tokens, +): + return dict( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + +class InferenceRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.data_source = input.data_source + self.state = GenerateState(input.args) + self.eval_prompt_dataset_cache = {} + + async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + if input.evaluation: + return await self._call_eval(input) + return await self._call_train(input) + + async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async + + output, aborted_samples = await generate_rollout_async( + self.state, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) + return output + + async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset + + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) + results_list = await asyncio.gather(*coros) + results = {k: v for r in results_list for k, v in r.items()} + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py new file mode 100644 index 000000000..2d052be0a --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -0,0 +1,112 @@ +import asyncio +import copy +import logging +from typing import Any + +from tqdm import tqdm + +from miles.rollout.inference_rollout.inference_rollout_common import ( + GenerateState, + compute_sampling_params, + generate_and_rm, +) +from miles.utils.data import Dataset +from miles.utils.eval_config import EvalDatasetConfig +from miles.utils.misc import as_completed_async +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def eval_rollout_single_dataset( + state: GenerateState, + 
dataset_cfg: EvalDatasetConfig, + prompt_dataset_cache: dict[Any, Dataset], +) -> dict[str, dict[str, list[Any]]]: + args = state.args + assert not args.group_rm, "Group RM is not supported for eval rollout" + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + if cache_key not in prompt_dataset_cache: + tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + prompt_dataset_cache[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = prompt_dataset_cache[cache_key] + + base_sampling_params = compute_sampling_params( + args, + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sampling_params = base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + state, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + async for sample in as_completed_async(tasks): + if do_print: + # TODO improve this after enhancing samples' type + s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample + if s is not None: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(s.prompt) + s.response]} " + f"reward={s.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py new file mode 100644 index 000000000..bae94ec67 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -0,0 +1,146 @@ +import asyncio +import logging +from argparse import Namespace +from collections.abc import Callable + +import sglang_router +from packaging.version import parse +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from 
miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group +from miles.utils.http_utils import get, post +from miles.utils.misc import as_completed_async, load_function +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: + args = state.args + + assert not state.aborted + state.aborted = True + + urls = await get_worker_urls(args) + logger.info(f"Abort request for {urls}") + await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) + + # make sure all the pending tasks are finished + aborted_samples = [] + async for group in as_completed_async(pendings): + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + + if args.partial_rollout: + logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") + + return aborted_samples + + +async def get_worker_urls(args: Namespace): + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + +def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): + return [ + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + state, + group, + sampling_params=state.sampling_params.copy(), + evaluation=False, + ) + ) + for group in samples + ] + + +async def generate_rollout_async( + state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + args = state.args + assert args.rollout_global_dataset + + # instantiate data filters + dynamic_filter = load_function(args.dynamic_sampling_filter_path) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + pendings = set() + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while len(data) + len(pendings) < target_data_size: + # get samples from the buffer and submit the generation requests. 
+ samples = data_source(args.over_sampling_batch_size) + pendings.update(submit_generate_tasks(state, samples)) + + # wait for the generation to finish + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(state, pendings, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. + state.reset() + + if f := load_function(args.rollout_sample_filter_path): + f(args, data) + # There can be circumstances where users want to process all samples including filtered ones. + if f := load_function(args.rollout_all_samples_process_path): + f(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index 62b253dde..e9ee29db4 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], + inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float]: +) -> list[int | float] | None: + if inplace_set_reward_field: + rewards = await batched_async_rm(args, samples, **kwargs) + for sample, reward in zip(samples, rewards, strict=True): + assert ( + sample.reward is None + ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" + sample.reward = reward + return None + if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/rollout/rm_hub/math_utils.py b/miles/rollout/rm_hub/math_utils.py index cab786797..94ec98b92 100644 --- a/miles/rollout/rm_hub/math_utils.py +++ b/miles/rollout/rm_hub/math_utils.py @@ -18,7 +18,7 @@ def mathd_normalize_answer(answer: str | None) -> str | None: answer = answer.strip() try: # Remove enclosing `\text{}`. 
- m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer) + m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer) if m is not None: answer = m.group("text").strip() return _strip_string(answer) @@ -124,7 +124,7 @@ def _fix_sqrt(string): # remove percentage string = string.replace("\\%", "") - string = string.replace("\%", "") + string = string.replace(r"\%", "") # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(" .", " 0.") @@ -161,7 +161,7 @@ def _fix_sqrt(string): # sympy might hang -- we don't care about trying to be lenient in these cases BAD_SUBSTRINGS = ["^{", "^("] -BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] TUPLE_CHARS = "()[]" @@ -220,7 +220,7 @@ def _str_is_int(x: str) -> bool: return False -def _str_to_int(x: str) -> bool: +def _str_to_int(x: str) -> int: x = x.replace(",", "") x = float(x) return int(x) @@ -238,7 +238,7 @@ def _inject_implicit_mixed_number(step: str): def _strip_properly_formatted_commas(expr: str): # We want to be careful because we don't want to strip tuple commas - p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") while True: next_expr = p1.sub("\\1\\3\\4", expr) if next_expr == expr: @@ -253,7 +253,7 @@ def _normalize(expr: str) -> str: return None # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr) + m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr) if m is not None: expr = m.group("text") @@ -286,8 +286,8 @@ def _normalize(expr: str) -> str: "inch", "yard", ]: - expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) - expr = re.sub("\^ *\\\\circ", "", expr) + expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) + expr = re.sub(r"\^ *\\circ", "", expr) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] diff --git a/miles/rollout/sft_rollout.py index 1e8a96c85..8669e380f 100644 --- a/miles/rollout/sft_rollout.py +++ b/miles/rollout/sft_rollout.py @@ -1,8 +1,7 @@ import logging -from transformers import AutoTokenizer - from miles.utils.mask_utils import MultiTurnLossMaskGenerator +from miles.utils.processing_utils import load_processor, load_tokenizer __all__ = ["generate_rollout"] @@ -10,6 +9,7 @@ TOKENIZER = None +PROCESSOR = None MASK_GENERATOR = None SAMPLE_PRINTED = False @@ -29,9 +29,12 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False): assert not evaluation assert args.rollout_global_dataset - global TOKENIZER, MASK_GENERATOR, SAMPLE_PRINTED + global TOKENIZER, PROCESSOR, MASK_GENERATOR, SAMPLE_PRINTED if TOKENIZER is None: - TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + TOKENIZER = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + + if PROCESSOR is None: + PROCESSOR = load_processor(args.hf_checkpoint, trust_remote_code=True) if MASK_GENERATOR is None: MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type) @@ -41,7 +44,10 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False): for i, sample in enumerate(samples): (sample,) = sample messages = sample.prompt - token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages) + tools = sample.metadata.get("tools", None) + + token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools) + response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0] sample.tokens = token_ids diff --git a/miles/rollout/sglang_rollout.py
b/miles/rollout/sglang_rollout.py index 2e33542a5..91918340a 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -1,30 +1,26 @@ import asyncio import copy +import inspect import logging from argparse import Namespace -from collections import defaultdict from collections.abc import Callable +from contextlib import contextmanager from typing import Any import numpy as np +import pybase64 import sglang_router from packaging.version import parse from tqdm import tqdm from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput -from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from miles.utils.async_utils import run from miles.utils.data import Dataset from miles.utils.eval_config import EvalDatasetConfig from miles.utils.http_utils import get, post -from miles.utils.mask_utils import get_response_lengths from miles.utils.misc import SingletonMeta, load_function -from miles.utils.processing_utils import ( - encode_image_for_rollout_engine, - load_processor, - load_tokenizer, - prepare_model_inputs, -) +from miles.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer from miles.utils.types import Sample from .rm_hub import async_rm, batched_async_rm @@ -64,8 +60,24 @@ def __init__(self, args: Namespace) -> None: sampling_seed_base = args.rollout_seed self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] + # dp rank balancing + self.dp_counts = [0] * (args.sglang_dp_size or 1) + self.dp_rank = 0 + self.reset() + @contextmanager + def dp_rank_context(self): + candidates = [i for i, count in enumerate(self.dp_counts) if count == min(self.dp_counts)] + dp_rank = int(np.random.choice(candidates)) + self.dp_counts[dp_rank] += 1 + self.dp_rank = dp_rank + try: + yield dp_rank + finally: + self.dp_counts[dp_rank] -= 1 + assert self.dp_counts[dp_rank] >= 0 + def reset(self) -> None: self.remaining_batch_size = 0 self.pendings = set() @@ -89,6 +101,9 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: """Generate using traditional SGLang router with token-based workflow""" + if args.ci_test: + assert isinstance(sample.prompt, str) + state = GenerateState(args) url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" @@ -96,17 +111,14 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED ), f"Sample status is {sample.status}" - prompt_ids, extra_info = prepare_model_inputs( - sample.prompt, - state.tokenizer, - state.processor, - sample.metadata, - args.apply_chat_template_kwargs, - ) - - image_data = extra_info.get("images", []) - video_data = extra_info.get("videos", []) - multimodal_inputs = extra_info.get("multimodal_inputs", None) + if state.processor: + processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) if len(sample.response) > 0: sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) @@ -127,12 +139,9 @@ 
async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A if args.use_rollout_routing_replay: payload["return_routed_experts"] = True - if image_data: + if sample.multimodal_inputs and sample.multimodal_inputs["images"]: + image_data = sample.multimodal_inputs["images"] payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - sample.multimodal_inputs = multimodal_inputs - - if video_data: - raise NotImplementedError("Video data is not supported yet") # Use existing tokens for multi-turn or tokenize the new prompt if len(sample.response) > 0: @@ -144,20 +153,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A output = await post(url, payload) - # Extract new response tokens - if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: - assert not args.partial_rollout, "Currently parital rollout is not suppurted when using miles router" - retrieve_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/retrieve_from_text" - retrieve_payload = {"text": sample.prompt + output["text"], "return_logp": True} - retrieve_output = await post(retrieve_url, retrieve_payload) - sample.tokens = retrieve_output["tokens"] - sample.response += output["text"] - sample.loss_mask = retrieve_output["loss_mask"] - sample.response_length = get_response_lengths([sample.loss_mask])[0] - sample.loss_mask = sample.loss_mask[-sample.response_length :] - sample.rollout_log_probs = retrieve_output["rollout_logp"][-sample.response_length :] - # Notice: currently cannot get the spec info from radix router output. + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + sample = await postprocess_sample_with_radix_tree(args, sample, output) else: if "output_token_logprobs" in output["meta_info"]: new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] @@ -174,27 +173,17 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A sample.rollout_log_probs = [] sample.rollout_log_probs += new_response_log_probs - if args.sglang_speculative_algorithm: - # cannot directly use spec info from sglang because of partial rollout. 
- sample.spec_info.add( - meta_info=output["meta_info"], - response_length=sample.response_length, - ) - - if "weight_version" in output["meta_info"]: - sample.weight_versions.append(output["meta_info"]["weight_version"]) - if "routed_experts" in output["meta_info"]: - assert len(output["meta_info"]["routed_experts"]) == len(sample.tokens) - 1 - sample.rollout_routed_experts = np.array(output["meta_info"]["routed_experts"]) + sample.rollout_routed_experts = np.frombuffer( + pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), + dtype=np.int32, + ).reshape( + len(sample.tokens) - 1, + args.num_layers, + args.moe_router_topk, + ) - match output["meta_info"]["finish_reason"]["type"]: - case "length": - sample.status = Sample.Status.TRUNCATED - case "abort": - sample.status = Sample.Status.ABORTED - case "stop": - sample.status = Sample.Status.COMPLETED + sample.update_from_meta_info(args, output["meta_info"]) return sample @@ -205,6 +194,10 @@ async def generate_and_rm( sampling_params: dict[str, Any], evaluation: bool = False, ) -> Sample | list[Sample]: + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + # For samples with existing response, check if they're complete if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: assert sample.response is not None @@ -220,11 +213,16 @@ async def generate_and_rm( sample.status = Sample.Status.ABORTED return sample - if args.custom_generate_function_path is not None: - custom_generate_func = load_function(args.custom_generate_function_path) - sample = await custom_generate_func(args, sample, sampling_params) - else: - sample = await generate(args, sample, sampling_params) + with state.dp_rank_context() as _: + if args.custom_generate_function_path is not None: + custom_generate_func = load_function(args.custom_generate_function_path) + # if signature has evaluation, pass evaluation + if "evaluation" in inspect.signature(custom_generate_func).parameters: + sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) + else: + sample = await custom_generate_func(args, sample, sampling_params) + else: + sample = await generate(args, sample, sampling_params) # for the rm that need the whole group, we will not do the rm here if args.group_rm: @@ -266,7 +264,9 @@ async def generate_and_rm_group( if getattr(args, "sglang_enable_deterministic_inference", False): seed = state.group_sampling_seeds[idx] current_sampling_params["sampling_seed"] = seed - tasks.append(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) + tasks.append( + asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) + ) group = await asyncio.gather(*tasks) @@ -343,12 +343,13 @@ async def generate_rollout_async( load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None ) - metric_gatherer = _MetricGatherer() + metric_gatherer = MetricGatherer() # target_data_size is the total number of valid samples to get target_data_size = args.rollout_batch_size data = [] + all_data = [] do_print = True pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") while len(data) < target_data_size: @@ -370,7 +371,8 @@ async def generate_rollout_async( do_print = False assert len(group) == args.n_samples_per_prompt - 
dynamic_filter_output = _call_dynamic_filter(dynamic_filter, args, group) + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) if not dynamic_filter_output.keep: metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) state.remaining_batch_size -= 1 @@ -393,6 +395,9 @@ async def generate_rollout_async( assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) # reset the global state to prevent effects on the next rollout or eval. state.reset() @@ -400,36 +405,12 @@ async def generate_rollout_async( filter_func = load_function(args.rollout_sample_filter_path) filter_func(args, data) - return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples - + # There can be circumstances where users want to process all samples including filtered ones. + if args.rollout_all_samples_process_path is not None: + process_func = load_function(args.rollout_all_samples_process_path) + process_func(args, all_samples, data_source) -def _call_dynamic_filter(fn, *args, **kwargs): - if fn is None: - return DynamicFilterOutput(keep=True) - - output = fn(*args, **kwargs) - - # compatibility for legacy version - if not isinstance(output, DynamicFilterOutput): - output = DynamicFilterOutput(keep=output) - - return output - - -class _MetricGatherer: - def __init__(self): - self._dynamic_filter_drop_reason_count = defaultdict(lambda: 0) - - def on_dynamic_filter_drop(self, reason: str | None): - if not reason: - return - self._dynamic_filter_drop_reason_count[reason] += 1 - - def collect(self): - return { - f"rollout/dynamic_filter/drop_{reason}": count - for reason, count in self._dynamic_filter_drop_reason_count.items() - } + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples EVAL_PROMPT_DATASET = {} @@ -508,17 +489,19 @@ async def eval_rollout_single_dataset( sampling_params = base_sampling_params.copy() sampling_params["sampling_seed"] = args.rollout_seed + j tasks.append( - generate_and_rm( - args, - sample, - sampling_params=sampling_params, - evaluation=True, + asyncio.create_task( + generate_and_rm( + args, + sample, + sampling_params=sampling_params, + evaluation=True, + ) ) ) data = [] do_print = True - pbar = tqdm(total=len(tasks), desc="Rollout generation", disable=not do_print) + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) for coro in asyncio.as_completed(tasks): sample = await coro if do_print: @@ -547,9 +530,8 @@ async def eval_rollout_single_dataset( } -# TODO remove this temp function def generate_rollout( - args: Namespace, rollout_id: int, data_buffer: Any, evaluation: bool = False + args: Namespace, rollout_id: int, data_source: Any, evaluation: bool = False ) -> RolloutFnTrainOutput | RolloutFnEvalOutput: """An example to implement the generate_rollout function for an rule based rm rollout generation. 
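The loop above hands every completed group to call_dynamic_filter, which loads the function named by --dynamic-sampling-filter-path and expects a DynamicFilterOutput back. A minimal sketch of such a filter, assuming DynamicFilterOutput accepts the keep/reason fields that the metric gatherer reads; the module and function names below are hypothetical, and rewards are assumed to be scalars:

# my_filters.py (hypothetical); enable with --dynamic-sampling-filter-path my_filters.keep_mixed_reward_groups
from miles.rollout.filter_hub.base_types import DynamicFilterOutput
from miles.utils.types import Sample


def keep_mixed_reward_groups(args, group: list[Sample]) -> DynamicFilterOutput:
    # Drop prompt groups whose rewards are all identical: with GRPO-style normalization
    # a constant-reward group contributes zero advantage, so it adds no training signal.
    rewards = [sample.reward for sample in group]  # assumes scalar rewards
    if len(set(rewards)) <= 1:
        # reason= is an assumption based on metric_gatherer reading dynamic_filter_output.reason
        return DynamicFilterOutput(keep=False, reason="constant_reward")
    return DynamicFilterOutput(keep=True)

Groups rejected this way are still appended to all_data before the filter runs, so a hook registered via --rollout-all-samples-process-path still sees them.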
@@ -562,20 +544,11 @@ def generate_rollout( Returns: list[list[Sample]]: a list of list of samples generated by the rollout """ - output, aborted_samples = generate_abortable_samples( - args, rollout_id, data_buffer.get_samples, evaluation=evaluation - ) - data_buffer.add_samples(aborted_samples) - return output - - -def generate_abortable_samples( - args: Namespace, - rollout_id: int, - data_source: Callable[[int], list[list[Sample]]], - evaluation: bool = False, -) -> tuple[Any, list[list[Sample]]]: assert args.rollout_global_dataset if evaluation: - return run(eval_rollout(args, rollout_id)) - return run(generate_rollout_async(args, rollout_id, data_source)) + output, _ = run(eval_rollout(args, rollout_id)) + return output + + output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) + data_source.add_samples(aborted_samples) + return output diff --git a/miles/router/middleware_hub/radix_tree.py b/miles/router/middleware_hub/radix_tree.py index 3c5dc769a..6e722f1e2 100644 --- a/miles/router/middleware_hub/radix_tree.py +++ b/miles/router/middleware_hub/radix_tree.py @@ -621,11 +621,6 @@ def retrieve_from_text(self, text: str, return_logprob: bool = True): # Create trie instance for testing trie = StringRadixTrie(max_cache_size=100, verbose=True) - # Test token retrieval - print("\nTesting token retrieval:") - test_tokens = trie.retrieve_from_text("Hello world") - print(f"Tokens for 'Hello world': {test_tokens}") - # Example usage with simplified insert test_cases = [ ("Hello world", [1, 2, 3], [-0.1, -0.2, -0.3]), diff --git a/miles/router/middleware_hub/radix_tree_middleware.py b/miles/router/middleware_hub/radix_tree_middleware.py index a173cd0d6..db57f6456 100644 --- a/miles/router/middleware_hub/radix_tree_middleware.py +++ b/miles/router/middleware_hub/radix_tree_middleware.py @@ -1,5 +1,5 @@ +import asyncio import json -from time import sleep from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware @@ -7,6 +7,10 @@ from starlette.responses import Response from transformers import AutoTokenizer +from miles.utils.http_utils import post +from miles.utils.mask_utils import get_response_lengths +from miles.utils.types import Sample + from .radix_tree import StringRadixTrie # Hop-by-hop headers that should not be forwarded @@ -108,7 +112,7 @@ async def dispatch(self, request: Request, call_next): ): break # await 30 seconds for aborted responses - sleep(30) + await asyncio.sleep(30) if isinstance(response_data, dict) and "text" in response_data and "output_ids" in response_data: generated_text = response_data["text"] @@ -149,3 +153,17 @@ async def dispatch(self, request: Request, call_next): if getattr(self.router, "verbose", False): print(f"[miles-router] Warning: Failed to cache trajectory: {e}") return response + + +async def postprocess_sample_with_radix_tree(args, sample: Sample, output: dict): + assert not args.partial_rollout, "Currently partial rollout is not supported when using miles router" + retrieve_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/retrieve_from_text" + retrieve_payload = {"text": sample.prompt + output["text"], "return_logp": True} + retrieve_output = await post(retrieve_url, retrieve_payload) + sample.tokens = retrieve_output["tokens"] + sample.response += output["text"] + sample.loss_mask = retrieve_output["loss_mask"] + sample.response_length = get_response_lengths([sample.loss_mask])[0] + sample.loss_mask = sample.loss_mask[-sample.response_length :] + 
sample.rollout_log_probs = retrieve_output["rollout_logp"][-sample.response_length :] + return sample diff --git a/miles/router/router.py b/miles/router/router.py index 88179a293..7d3ecd980 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -1,5 +1,7 @@ import argparse +import asyncio import json +import logging import httpx import uvicorn @@ -7,8 +9,11 @@ from fastapi.responses import JSONResponse from starlette.responses import Response +from miles.router.sessions import setup_session_routes from miles.utils.misc import load_function +logger = logging.getLogger(__name__) + def run_router(args): """ @@ -28,9 +33,14 @@ def __init__(self, args, verbose=False): self.verbose = verbose self.app = FastAPI() - - # Worker information - self.worker_urls: dict[str, int] = {} + self.app.add_event_handler("startup", self._start_background_health_check) + + # URL -> Active Request Count (load state) + self.worker_request_counts: dict[str, int] = {} + # URL -> Consecutive Failures + self.worker_failure_counts: dict[str, int] = {} + # Quarantined workers excluded from routing pool + self.dead_workers: set[str] = set() self.max_weight_version = None max_connections = getattr(args, "miles_router_max_connections", None) @@ -60,48 +70,104 @@ def _setup_routes(self): self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) self.app.post("/retrieve_from_text")(self.retrieve_from_text) + # Session routes - must be registered before catch-all + setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) - async def health_check(self, request: Request): - # TODO: do health check in background - pass + async def _start_background_health_check(self): + asyncio.create_task(self._health_check_loop()) + + async def _check_worker_health(self, url): + """Encapsulated health check logic for better maintainability.""" + try: + response = await self.client.get(f"{url}/health", timeout=5.0) + if response.status_code == 200: + return url, True + logger.debug(f"[miles-router] Worker {url} is unhealthy (Status: {response.status_code})") + except Exception as e: + logger.debug(f"[miles-router] Worker {url} health check failed: {e}") + return url, False + + async def _health_check_loop(self): + """Background loop to monitor worker health and adjust routing pool.""" + interval = self.args.rollout_health_check_interval + threshold = self.args.miles_router_health_check_failure_threshold + + while True: + try: + await asyncio.sleep(interval) + + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + continue + + results = await asyncio.gather(*(self._check_worker_health(url) for url in urls)) + + for url, is_healthy in results: + if not is_healthy: + failures = self.worker_failure_counts.get(url, 0) + 1 + self.worker_failure_counts[url] = failures + + if failures >= threshold: + logger.warning( + f"[miles-router] Worker {url} failed {threshold} consecutive health checks. Marking as DEAD." + ) + self.dead_workers.add(url) + # TODO (chenyang): Connect back 'dead' workers requires a mechanism to sync + # model versions to avoid off-policy issues from stale weights, since these + # dead workers' parameters may not be refitted. + else: + self.worker_failure_counts[url] = 0 + + logger.debug( + f"[miles-router] Health check complete. {len(self.worker_request_counts) - len(self.dead_workers)} workers healthy." 
+ ) + + except asyncio.CancelledError: + logger.warning("[miles-router] Background health check loop is being cancelled.") + raise + except Exception as e: + logger.error(f"[miles-router] Unexpected error in health check loop: {e}", exc_info=True) + await asyncio.sleep(5) async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - # Forward all other paths to SGLang router + result = await self._do_proxy(request, path) + return self._build_proxy_response(result) + + async def _do_proxy(self, request: Request, path: str) -> dict: + """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - # Get request body and headers body = await request.body() headers = dict(request.headers) try: response = await self.client.request(request.method, url, content=body, headers=headers) - # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - content_type = response.headers.get("content-type", "") - try: - # Prefer parsing JSON if possible - data = json.loads(content) - return JSONResponse( - content=data, - status_code=response.status_code, - headers=dict(response.headers), - ) - except Exception: - # Fall back to raw body with original content type - return Response( - content=content, - status_code=response.status_code, - headers=dict(response.headers), - media_type=content_type or None, - ) - + return { + "request_body": body, + "response_body": content, + "status_code": response.status_code, + "headers": dict(response.headers), + } finally: self._finish_url(worker_url) + def _build_proxy_response(self, result: dict) -> Response: + """Build HTTP response from proxy result.""" + content = result["response_body"] + status_code = result["status_code"] + headers = result["headers"] + content_type = headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. 
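The health-check loop above only quarantines a worker after --miles-router-health-check-failure-threshold consecutive failed /health probes, and a single successful probe clears the streak. A self-contained sketch of that bookkeeping, with plain module-level state and made-up worker URLs standing in for the MilesRouter instance:

# Illustration of the consecutive-failure quarantine used by the router's health-check loop.
worker_failure_counts: dict[str, int] = {}
dead_workers: set[str] = set()
THRESHOLD = 3  # default of --miles-router-health-check-failure-threshold


def record_health(url: str, is_healthy: bool) -> None:
    if is_healthy:
        worker_failure_counts[url] = 0  # any success resets the failure streak
        return
    worker_failure_counts[url] = worker_failure_counts.get(url, 0) + 1
    if worker_failure_counts[url] >= THRESHOLD:
        dead_workers.add(url)  # excluded from routing until a reconnect mechanism exists


for ok in (False, False, False):
    record_health("http://worker-a:30000", ok)
assert "http://worker-a:30000" in dead_workers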
@@ -124,16 +190,17 @@ async def add_worker(self, request: Request): ) # Add if new, keep a simple request count per worker - if worker_url not in self.worker_urls: - self.worker_urls[worker_url] = 0 + if worker_url not in self.worker_request_counts: + self.worker_request_counts[worker_url] = 0 + self.worker_failure_counts[worker_url] = 0 if self.verbose: print(f"[miles-router] Added new worker: {worker_url}") - return {"status": "success", "worker_urls": self.worker_urls} + return {"status": "success", "worker_urls": self.worker_request_counts} async def list_workers(self, request: Request): """List all registered workers""" - return {"urls": list(self.worker_urls.keys())} + return {"urls": list(self.worker_request_counts.keys())} async def retrieve_from_text(self, request: Request): """Get token information from text input""" @@ -158,19 +225,27 @@ async def retrieve_from_text(self, request: Request): return result def _use_url(self): - """Select a worker URL using round-robin strategy""" - assert len(self.worker_urls) > 0, "No workers available" + """Select worker URL with minimal active requests.""" + + if not self.dead_workers: + # Healthy path: select from all workers + url = min(self.worker_request_counts, key=self.worker_request_counts.get) + else: + # Degraded path: select from workers not in dead_workers + valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) + try: + url = min(valid_workers, key=self.worker_request_counts.get) + except ValueError: + raise RuntimeError("No healthy workers available in the pool") from None - # get the url with mininal count - url = min(self.worker_urls, key=self.worker_urls.get) - self.worker_urls[url] += 1 + self.worker_request_counts[url] += 1 return url def _finish_url(self, url): """Mark the request to the given URL as finished""" - assert url in self.worker_urls, f"URL {url} not recognized" - self.worker_urls[url] -= 1 - assert self.worker_urls[url] >= 0, f"URL {url} count went negative" + assert url in self.worker_request_counts, f"URL {url} not recognized" + self.worker_request_counts[url] -= 1 + assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative" if __name__ == "__main__": diff --git a/miles/router/sessions.py b/miles/router/sessions.py new file mode 100644 index 000000000..9d753e597 --- /dev/null +++ b/miles/router/sessions.py @@ -0,0 +1,124 @@ +import json +import time +import uuid +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from pydantic import BaseModel +from transformers import AutoTokenizer + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class GetSessionResponse(BaseModel): + session_id: str + records: list[SessionRecord] + + +class SessionManager: + def __init__(self): + self.sessions: dict[str, list[SessionRecord]] = {} + + def create_session(self) -> str: + session_id = uuid.uuid4().hex + self.sessions[session_id] = [] + return session_id + + def get_session(self, session_id: str) -> list[SessionRecord] | None: + return self.sessions.get(session_id) + + def delete_session(self, session_id: str) -> list[SessionRecord]: + assert session_id in self.sessions + return self.sessions.pop(session_id) + + def add_record(self, session_id: str, record: SessionRecord): + assert session_id in self.sessions + self.sessions[session_id].append(record) + + +def 
setup_session_routes(app, router: "MilesRouter"): + manager = SessionManager() + + # TODO temporary hack before @guapisolo implements TITO + # ============================= HACK START =============================== + # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) + tokenizer = None + + def get_tokenizer(): + nonlocal tokenizer + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + return tokenizer + + # ============================= HACK END =============================== + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + records = manager.get_session(session_id) + if records is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return GetSessionResponse(session_id=session_id, records=records) + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + manager.delete_session(session_id) + return Response(status_code=204) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + if session_id not in manager.sessions: + return JSONResponse(status_code=404, content={"error": "session not found"}) + + result = await router._do_proxy(request, path) + + request_body = json.loads(result["request_body"]) + response_body = json.loads(result["response_body"]) + + # TODO: remove this hack when @guapisolo implements the real TITO + # ============================= HACK START =============================== + if "messages" in request_body and "input_ids" not in request_body: + request_body["input_ids"] = get_tokenizer().apply_chat_template( + request_body["messages"], + add_generation_prompt=True, + add_special_tokens=False, + tools=request_body.get("tools"), + ) + if ( + "logprobs" in response_body.get("choices", [{}])[0] + and "content" in response_body["choices"][0]["logprobs"] + ): + logprobs_content = response_body["choices"][0]["logprobs"]["content"] + for item in logprobs_content: + if "token" in item and "token_id" not in item: + item["token_id"] = get_tokenizer().convert_tokens_to_ids(item["token"]) + # ============================= HACK END =============================== + + record = SessionRecord( + timestamp=time.time(), + method=request.method, + path=path, + request=request_body, + response=response_body, + status_code=result["status_code"], + ) + manager.add_record(session_id, record) + + return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..f3dce7fbf 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,9 +10,10 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list - from miles.utils.logging_utils import configure_logger +from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -155,11 +156,26 @@ def 
add_train_arguments(parser): default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) + parser.add_argument( + "--custom-model-provider-path", + type=str, + default=None, + help=( + "Path to a custom model provider function. " + "If set, we will use this function instead of the default model provider. " + "The function should have the signature " + "`def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel`. " + "Example: 'my_module.my_model_provider'." + ), + ) parser.add_argument( "--recompute-loss-function", action="store_true", help="Whether to disable recompute loss function to save memory during training.", ) + parser.add_argument( + "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" + ) return parser @@ -177,11 +193,6 @@ def add_rollout_arguments(parser): "It doesn't necessary need to contain the most up-to-date parameters." ), ) - parser.add_argument( - "--use-hf-config-for-megatron", - action="store_true", - help="Whether to use HF config for Megatron core to define the model architecture.", - ) parser.add_argument( "--model-name", type=str, @@ -195,7 +206,11 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default="miles.rollout.sglang_rollout.generate_rollout", + default=( + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" + if enable_experimental_rollout_refactor() + else "miles.rollout.sglang_rollout.generate_rollout" + ), help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " @@ -240,7 +255,7 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-max-response-len", type=int, - default=1024, + default=None, help=( "The maximum length of the response for the inference engine during rollout. " "It is basically `max_tokens` in sglang." @@ -329,6 +344,15 @@ def add_rollout_arguments(parser): "This is useful for long responses." ), ) + parser.add_argument( + "--mask-offpolicy-in-partial-rollout", + action="store_true", + default=False, + help=( + "Whether to mask previous generation in partial rollout. " + "If set, only on-policy generated tokens will be used in training" + ), + ) parser.add_argument( "--custom-generate-function-path", type=str, @@ -338,6 +362,26 @@ def add_rollout_arguments(parser): "This should be useful if you need to implement some special rollout logic, e.g. multi-turn, function calling." ), ) + parser.add_argument( + "--custom-rollout-log-function-path", + type=str, + default=None, + help=( + "The custom function for logging rollout data. The signature of the functions is: " + "def log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time) -> bool. " + "The return value indicates whether to skip the default logging. " + ), + ) + parser.add_argument( + "--custom-eval-rollout-log-function-path", + type=str, + default=None, + help=( + "The custom function for logging eval rollout data. " + "def log_eval_rollout_data(rollout_id, args, data, extra_metrics) -> bool. " + "The return value indicates whether to skip the default logging. 
" + ), + ) parser.add_argument( "--buffer-filter-path", @@ -417,8 +461,8 @@ def add_fault_tolerance_arguments(parser): parser.add_argument( "--rollout-health-check-first-wait", type=float, - default=300.0, - help="Time to wait for the compilation before the actual health check.", + default=0, + help="Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm.", ) return parser @@ -491,7 +535,7 @@ def add_data_arguments(parser): parser.add_argument( "--tool-key", type=str, - default=None, + default="tools", help=( "When need to add tools during apply_chat_template, you should provide the key for the tools in the prompt dataset." ), @@ -613,6 +657,12 @@ def add_eval_arguments(parser): "When provided, this overrides --eval-prompt-data." ), ) + parser.add_argument( + "--skip-eval-before-train", + action="store_true", + default=False, + help="Whether to skip evaluation before training.", + ) # The following keys are used to override the rollout version during eval. parser.add_argument("--eval-input-key", type=str, default=None, help="JSON dataset key") @@ -650,6 +700,26 @@ def add_algo_arguments(parser): reset_arg(parser, "--load", type=str, default=None) reset_arg(parser, "--save", type=str, default=None) reset_arg(parser, "--save-interval", type=int, default=None) + reset_arg(parser, "--async-save", action="store_true") + reset_arg( + parser, + "--no-save-optim", + action="store_true", + default=False, + help=( + "If set, do not save the optimizer state when saving checkpoints. " + "This reduces checkpoint size but disables training resumption from the saved checkpoint." + ), + ) + parser.add_argument( + "--save-hf", + type=str, + default=None, + help=( + "Path to save the model in HuggingFace format when using Megatron backend. " + "The model will be saved to `save_hf.format(rollout_id)`. " + ), + ) reset_arg(parser, "--seed", type=int, default=1234) reset_arg(parser, "--clip-grad", type=float, default=1.0) reset_arg(parser, "--calculate-per-token-loss", action="store_true") @@ -782,6 +852,15 @@ def add_algo_arguments(parser): default=False, help="Whether to calculate the mismatch metrics.", ) + parser.add_argument( + "--reset-optimizer-states", + action="store_true", + default=False, + help=( + "Whether to reset optimizer states after each rollout. " + "If enabled, the optimizer's history will be cleared at the end of each rollout, which can sometimes help with training stability or fulfill specific experiment requirements." + ), + ) parser.add_argument( "--use-rollout-logprobs", action="store_true", @@ -816,6 +895,12 @@ def add_algo_arguments(parser): default=None, help="Path to the custom TIS/RS function (e.g., examples/train_infer_mismatch_helper/mis.py:compute_mis_weights_with_cp).", ) + parser.add_argument( + "--custom-pg-loss-reducer-function-path", + type=str, + default=None, + help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. 
(e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", + ) parser.add_argument( "--use-routing-replay", @@ -868,6 +953,12 @@ def add_router_arguments(parser): default=None, help="Max connections for MilesRouter HTTP client.", ) + parser.add_argument( + "--miles-router-health-check-failure-threshold", + type=int, + default=3, + help="Number of consecutive failures before marking a worker as unhealthy.", + ) RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) return parser @@ -934,6 +1025,12 @@ def add_wandb_arguments(parser): "Specify the key in the reward dict using this argument.", ), ) + parser.add_argument( + "--log-correct-samples", + action="store_true", + default=False, + help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", + ) parser.add_argument("--wandb-run-id", type=str, default=None) return parser @@ -1092,6 +1189,16 @@ def add_reward_model_arguments(parser): "Path to the custom function that will post process reward, by default it will be the normalization for grpo. " ), ) + parser.add_argument( + "--custom-convert-samples-to-train-data-path", + type=str, + default=None, + help=( + "Path to a custom function that converts samples to training data. " + "If set, this function will replace the default _convert_samples_to_train_data. " + "The function should have the signature `def convert_samples_to_train_data(args, samples) -> dict`." + ), + ) return parser def add_rollout_buffer_arguments(parser): @@ -1144,12 +1251,27 @@ def add_rollout_buffer_arguments(parser): "Note: This attribute does not determine whether the sample participates in advantage normalization." ), ) + parser.add_argument( + "--rollout-all-samples-process-path", + type=str, + default=None, + help=( + "Path to the rollout all samples process function that " + "can process all samples including filtered ones." 
+ ), + ) parser.add_argument( "--disable-rollout-trim-samples", action="store_true", default=False, help="disable trim samples in rollout buffer when converting samples to train data", ) + parser.add_argument( + "--use-dynamic-global-batch-size", + action="store_true", + default=False, + help="enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data", + ) return parser def add_custom_megatron_plugins_arguments(parser): @@ -1228,6 +1350,20 @@ def add_ci_arguments(parser): ) return parser + def add_user_provided_function_arguments(parser): + args_partial, _ = parser.parse_known_args() + for path in [ + args_partial.rollout_function_path, + args_partial.custom_generate_function_path, + ]: + try: + fn = load_function(path) + except (ModuleNotFoundError, ValueError): + continue + if fn is not None and callable(getattr(fn, "add_arguments", None)): + fn.add_arguments(parser) + return parser + def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1257,21 +1393,19 @@ def add_sglang_tp_size(): parser = add_mtp_training_arguments(parser) parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) - parser.set_defaults(sglang_tensor_parallel_size=add_sglang_tp_size()) - - # For megatron parser = add_custom_megatron_plugins_arguments(parser) - try: - parser.add_argument( - "--custom-config-path", - type=str, - default=None, - help="Path to the YAML config for custom function arguments.", - ) - parser.add_argument("--padded-vocab-size", type=int, default=None) - except argparse.ArgumentError: - pass + if enable_experimental_rollout_refactor(): + parser = add_user_provided_function_arguments(parser) + reset_arg( + parser, + "--custom-config-path", + type=str, + default=None, + help="Path to the YAML config for custom function arguments.", + ) + reset_arg(parser, "--padded-vocab-size", type=int, default=None) + parser.set_defaults(sglang_tensor_parallel_size=add_sglang_tp_size()) return parser return add_miles_arguments @@ -1285,19 +1419,13 @@ def parse_args(add_custom_arguments=None): backend = parse_args_train_backend() if backend == "megatron": - from miles.backends.megatron_utils import parse_args as megatron_parse_args - from miles.backends.megatron_utils import set_default_megatron_args - from miles.backends.megatron_utils import validate_args as megatron_validate_args + from miles.backends.megatron_utils.arguments import parse_args as megatron_parse_args + from miles.backends.megatron_utils.arguments import set_default_megatron_args + from miles.backends.megatron_utils.arguments import validate_args as megatron_validate_args args = megatron_parse_args(extra_args_provider=add_miles_arguments) if args.hf_checkpoint: hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) - if args.use_hf_config_for_megatron: - from miles.backends.megatron_utils.config_mapping import get_mapper - - megatron_config_from_hf = get_mapper(hf_config.model_type)(hf_config) - _validate_and_update_megatron_args_from_hf(args, megatron_config_from_hf.transformer_config) - _validate_and_update_megatron_args_from_hf(args, megatron_config_from_hf.gpt_model_args) hf_validate_args(args, hf_config) args.rank = 0 @@ -1310,6 +1438,8 @@ def parse_args(add_custom_arguments=None): args.rank = 0 # Primary process rank for wandb initialization args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node + 
assert args.context_parallel_size == 1, "Context parallelism is not supported for FSDP backend." + miles_validate_args(args) if backend == "megatron": @@ -1396,18 +1526,23 @@ def miles_validate_args(args): ) # TODO: During loading, we need to set the start_rollout_id here. - if ( - args.load is None - or not os.path.exists(args.load) - or not os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt")) - ): - args.no_load_optim = True - args.no_load_rng = True - args.finetune = True - args.load = args.ref_load - if args.ref_ckpt_step is not None: - args.ckpt_step = args.ref_ckpt_step + if args.megatron_to_hf_mode == "bridge": + if args.load is None: + args.load = args.ref_load or args.hf_checkpoint args.start_rollout_id = 0 + else: + if ( + args.load is None + or not os.path.exists(args.load) + or not os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt")) + ): + args.no_load_optim = True + args.no_load_rng = True + args.finetune = True + args.load = args.ref_load + if args.ref_ckpt_step is not None: + args.ckpt_step = args.ref_ckpt_step + args.start_rollout_id = 0 if args.eval_interval is not None: assert args.eval_datasets, "Evaluation datasets must be configured when eval_interval is set." @@ -1522,11 +1657,6 @@ def miles_validate_args(args): ) args.global_batch_size = global_batch_size - assert args.rollout_batch_size * args.n_samples_per_prompt % args.global_batch_size == 0, ( - f"rollout_batch_size {args.rollout_batch_size} * n_samples_per_prompt {args.n_samples_per_prompt} " - f"is not a multiple of global_batch_size {args.global_batch_size}" - ) - if args.n_samples_per_prompt == 1: args.grpo_std_normalization = False logger.info("n_samples_per_prompt is set to 1, grpo_std_normalization will be set to False.") @@ -1567,30 +1697,31 @@ def miles_validate_args(args): logger.info(f"Warning: Argument {k} is already set to {getattr(args, k)}, will override with {v}.") setattr(args, k, v) - if args.rollout_max_context_len is None: - logger.info( - f"args.rollout_max_context_len is not set. Use args.rollout_max_response_len {args.rollout_max_response_len} as default value." - ) - args.rollout_max_context_len = args.rollout_max_response_len - if args.eval_max_context_len is None: logger.info( f"args.eval_max_context_len is not set. Use args.rollout_max_context_len {args.rollout_max_context_len} as default value." ) args.eval_max_context_len = args.rollout_max_context_len - if args.rollout_max_prompt_len is None: - logger.info( - f"args.rollout_max_prompt_len is not set. Use args.rollout_max_context_len - 1 ({args.rollout_max_context_len} - 1) as default value so that there is at least one generated token to compute loss." - ) - args.rollout_max_prompt_len = args.rollout_max_context_len - 1 + if args.rollout_max_context_len is not None: + if args.rollout_max_prompt_len is None: + args.rollout_max_prompt_len = args.rollout_max_context_len - 1 + logger.info( + f"args.rollout_max_prompt_len is not set. Use args.rollout_max_context_len - 1 ({args.rollout_max_context_len} - 1) as default value so that there is at least one generated token to compute loss." + ) + assert ( + args.rollout_max_prompt_len <= args.rollout_max_context_len - 1 + ), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss." 
- assert ( - args.rollout_max_prompt_len <= args.rollout_max_context_len - 1 - ), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss." + assert not ( + args.prefill_num_servers is not None and args.rollout_external + ), "prefill_num_servers cannot be set when rollout_external is set." - if args.prefill_num_servers is not None: - assert not args.use_fault_tolerance, "fault tolerance is not supported when prefill_num_servers is set." + if args.qkv_format == "bshd": + assert args.train_backend == "megatron", "bshd format is only supported for megatron backend." + assert ( + args.use_dynamic_batch_size is False + ), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead." assert args.qkv_format in [ "thd", @@ -1609,6 +1740,10 @@ def equal(x, y): errors = [] + # multimodal models have different config structure + if hasattr(hf_config, "text_config"): + hf_config = hf_config.text_config + for hf_config_name, megatron_config_name, compare_fn in [ ("hidden_size", "hidden_size", equal), ("num_attention_heads", "num_attention_heads", equal), @@ -1627,12 +1762,3 @@ def equal(x, y): if len(errors) > 0: raise AssertionError("hf_validate_args failed: " + "; ".join(errors)) - - -def _validate_and_update_megatron_args_from_hf(args, args_from_hf_config: dict[str, Any]): - for key, value in args_from_hf_config.items(): - if hasattr(args, key) and getattr(args, key) != value: - raise ValueError( - f"Argument {key} is not consistent. {key} in args is {getattr(args, key)}, but from HF config is {value}." - ) - setattr(args, key, value) diff --git a/miles/utils/data.py b/miles/utils/data.py index c36902c81..eb512e514 100644 --- a/miles/utils/data.py +++ b/miles/utils/data.py @@ -1,3 +1,4 @@ +import itertools import json import logging import os @@ -5,9 +6,13 @@ import re import numpy as np -import pandas as pd import ray +try: + import pyarrow.parquet as pq +except ImportError: + pq = None + from miles.utils.types import MultimodalTypes, Sample from .timer import Timer @@ -17,26 +22,50 @@ logger = logging.getLogger(__name__) -# TODO: don't read the whole file into memory. def read_file(path): path, row_slice = _parse_generalized_path(path) + reader = None if not os.path.exists(path): raise FileNotFoundError(f"Prompt dataset path '{path}' does not exist.") if path.endswith(".jsonl"): - df = pd.read_json(path, lines=True, dtype={"label": str}) + + def jsonl_reader(p): + with open(p, encoding="utf-8") as f: + for line_num, line in enumerate(f): + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError as e: + print(f"JSON decode error at line {line_num}: {e}") + continue + + reader = jsonl_reader(path) + elif path.endswith(".parquet"): - df = pd.read_parquet(path, dtype_backend="pyarrow") + if pq is None: + raise ImportError("pyarrow is required for parquet support") + + def parquet_reader(p): + pf = pq.ParquetFile(p) + + for batch in pf.iter_batches(): + yield from batch.to_pylist() + + reader = parquet_reader(path) + else: raise ValueError(f"Unsupported file format: {path}. 
Supported formats are .jsonl and .parquet.") if row_slice is not None: - logger.info(f"read_file path={path} slice {len(df)=} rows into {row_slice=}") - df = df.iloc[row_slice] - for _, row in df.iterrows(): - yield row.to_dict() + logger.info("read_file path=%s applying slice row_slice=%s", path, row_slice) + reader = itertools.islice(reader, row_slice.start, row_slice.stop, row_slice.step) + + yield from reader def _parse_generalized_path(s: str): @@ -49,21 +78,49 @@ def _parse_generalized_path(s: str): return s, None -def _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs): +def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_length: int | None) -> list[Sample]: if max_length is None: return False - from miles.utils.processing_utils import prepare_model_inputs + if not isinstance(origin_samples[0].prompt, str): + logger.warning( + "Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering." + ) + return False + + if processor: + filtered_samples = [] + for sample in origin_samples: + from miles.utils.processing_utils import process_vision_info + + multimodal_inputs = process_vision_info(sample.prompt, processor) + processor_output = processor(text=sample.prompt, **multimodal_inputs) + input_ids = processor_output["input_ids"][0] + if len(input_ids) <= max_length: + filtered_samples.append(sample) + else: + prompts = [sample.prompt for sample in origin_samples] + input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"] + filtered_samples = [ + sample + for sample, input_ids in zip(origin_samples, input_ids_list, strict=True) + if len(input_ids) <= max_length + ] + + logger.info(f"Filtered {len(origin_samples) - len(filtered_samples)} samples longer than max_length={max_length}.") - input_ids, _ = prepare_model_inputs(prompt, tokenizer, processor, None, apply_chat_template_kwargs) - return len(input_ids) > max_length + return filtered_samples -def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None): - messages = data.get(prompt_key) +def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None): + prompt = data.get(prompt_key) - if isinstance(messages, str): - messages = [{"role": "user", "content": messages}] + if isinstance(prompt, str): + # If prompt is a string and we don't apply chat template, return the prompt as is. 
+ if not as_conversation: + return prompt + else: + prompt = [{"role": "user", "content": prompt}] if multimodal_keys: # Build mapping: placeholder -> (MultimodalType, content_list) @@ -75,7 +132,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None): pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")" - for message in messages: + for message in prompt: if isinstance(message["content"], str): content_list = [] for segment in re.split(pattern, message["content"]): @@ -105,7 +162,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None): f"Unsupported content type: {type(message['content'])}, expected str or list of dicts" ) - return messages + return prompt class Dataset: @@ -125,11 +182,14 @@ def __init__( apply_chat_template=False, apply_chat_template_kwargs=None, ): - self.origin_samples = [] + origin_samples = [] for data in read_file(path): - prompt = _build_messages(data, prompt_key, multimodal_keys) + # Both chat templates and multimodal inputs require conversation format (list of message dicts) + as_conversation = apply_chat_template or (multimodal_keys is not None) + prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys) metadata = data.get(metadata_key) or {} + tools = None if tool_key is not None and tool_key in data: tools = data[tool_key] if isinstance(tools, str): @@ -139,18 +199,49 @@ def __init__( assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" metadata["tools"] = tools - # TODO: this is slow. - if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs): - continue + if apply_chat_template: + ### DSV32 + try: + prompt = tokenizer.apply_chat_template( + prompt, + tools, + tokenize=False, + add_generation_prompt=True, + **apply_chat_template_kwargs, + ) + except Exception as e: + from sglang.srt.entrypoints.openai.encoding_dsv32 import encode_messages + encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True) + prompt = encode_messages(prompt, **encode_config) + ### DSV32 + output_prompt = prompt + else: + output_prompt = prompt + + if processor: + from miles.utils.processing_utils import process_vision_info + + assert isinstance( + prompt, list + ), f"prompt must be a list when processor is not None, got {type(prompt)} instead" + multimodal_inputs = process_vision_info(prompt, processor) + else: + multimodal_inputs = None - self.origin_samples.append( + origin_samples.append( Sample( - prompt=prompt, + prompt=output_prompt, label=data[label_key] if label_key is not None else None, metadata=metadata, + multimodal_inputs=multimodal_inputs, ) ) + if max_length is not None: + self.origin_samples = filter_long_prompt(origin_samples, tokenizer, processor, max_length) + else: + self.origin_samples = origin_samples + self.epoch_id = -1 self.seed = seed self.samples = self.origin_samples diff --git a/miles/utils/debug_utils/display_debug_rollout_data.py b/miles/utils/debug_utils/display_debug_rollout_data.py index 3036e16ea..5775877b5 100644 --- a/miles/utils/debug_utils/display_debug_rollout_data.py +++ b/miles/utils/debug_utils/display_debug_rollout_data.py @@ -6,7 +6,7 @@ import torch import typer -from miles.ray.rollout import compute_metrics_from_samples +from miles.ray.rollout import compute_perf_metrics_from_samples from miles.utils.types import Sample _WHITELIST_KEYS = [ @@ -47,7 +47,7 @@ def main( log_reward_category=None, ) sample_objects = [Sample.from_dict(s) for s in 
sample_dicts] - metrics = compute_metrics_from_samples(args, sample_objects) + metrics = compute_perf_metrics_from_samples(args, sample_objects) print("metrics", metrics) if show_samples: diff --git a/miles/utils/environ.py b/miles/utils/environ.py new file mode 100644 index 000000000..35d1f350e --- /dev/null +++ b/miles/utils/environ.py @@ -0,0 +1,14 @@ +import os + +_printed_experimental_rollout_refactor = False + + +def enable_experimental_rollout_refactor() -> bool: + result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + + global _printed_experimental_rollout_refactor + if result and not _printed_experimental_rollout_refactor: + print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") + _printed_experimental_rollout_refactor = True + + return result diff --git a/miles/utils/eval_config.py b/miles/utils/eval_config.py index 4a7c1e912..69b4464b4 100644 --- a/miles/utils/eval_config.py +++ b/miles/utils/eval_config.py @@ -111,6 +111,9 @@ class EvalDatasetConfig: top_p: float | None = None top_k: int | None = None max_response_len: int | None = None + stop: list[str] | None = None + stop_token_ids: list[int] | None = None + min_new_tokens: int | None = None metadata_overrides: dict[str, Any] = field(default_factory=dict) diff --git a/miles/utils/external_utils/command_utils.py b/miles/utils/external_utils/command_utils.py index ac7f3ce46..bdfc864b4 100644 --- a/miles/utils/external_utils/command_utils.py +++ b/miles/utils/external_utils/command_utils.py @@ -26,6 +26,7 @@ def convert_checkpoint( extra_args: str = "", dir_dst: str = "/root", hf_checkpoint: str | None = None, + megatron_path: str = "/host_home/primary_synced/Megatron-LM", ): hf_checkpoint = hf_checkpoint or f"/root/models/{model_name}" @@ -51,14 +52,15 @@ def convert_checkpoint( exec_command( f"source {repo_base_dir}/scripts/models/{megatron_model_type}.sh && " - f"PYTHONPATH=/root/Megatron-LM " + # Use installed Megatron instead of hardcoded path + f"PYTHONPATH={megatron_path} " f"torchrun " f"--nproc-per-node {num_gpus_per_node} " f"{multinode_args}" f"tools/convert_hf_to_torch_dist.py " "${MODEL_ARGS[@]} " f"--hf-checkpoint {hf_checkpoint} " - f"--save {path_dst}" + f"--save {path_dst} " f"{extra_args}" ) @@ -67,9 +69,9 @@ def rsync_simple(path_src: str, path_dst: str): exec_command(f"mkdir -p {path_dst} && rsync -a --info=progress2 {path_src}/ {path_dst}") -def hf_download_dataset(full_name: str): +def hf_download_dataset(full_name: str, data_dir: str = "/root/datasets"): _, partial_name = full_name.split("/") - exec_command(f"hf download --repo-type dataset {full_name} --local-dir /root/datasets/{partial_name}") + exec_command(f"hf download --repo-type dataset {full_name} --local-dir {data_dir}/{partial_name}") def fp8_cast_bf16(path_src, path_dst): @@ -98,6 +100,7 @@ def execute_train( before_ray_job_submit=None, extra_env_vars=None, config: ExecuteTrainConfig | None = None, + megatron_path: str = "/host_home/primary_synced/Megatron-LM", ): if extra_env_vars is None: extra_env_vars = {} @@ -139,7 +142,8 @@ def execute_train( runtime_env_json = json.dumps( { "env_vars": { - "PYTHONPATH": "/root/Megatron-LM/", + # Use installed Megatron instead of hardcoded path + "PYTHONPATH": f"{megatron_path}", # If setting this in FSDP, the computation communication overlapping may have issues **( {} @@ -213,12 +217,13 @@ def get_default_wandb_args(test_file: str, run_name_prefix: str | None = None, r if (x := run_name_prefix) is not None: wandb_run_name = f"{x}_{wandb_run_name}" - # 
do not put wandb_api_key value here to avoid leaking to logs explicitly + # Use the actual key value from environment to avoid shell expansion issues + wandb_key = os.environ.get("WANDB_API_KEY") return ( "--use-wandb " f"--wandb-project miles-{test_name} " f"--wandb-group {wandb_run_name} " - f"--wandb-key ${{WANDB_API_KEY}} " + f"--wandb-key '{wandb_key}' " "--disable-wandb-random-suffix " ) diff --git a/miles/utils/flops_utils.py b/miles/utils/flops_utils.py index 71cdd4c65..75afccc05 100644 --- a/miles/utils/flops_utils.py +++ b/miles/utils/flops_utils.py @@ -6,20 +6,43 @@ def calculate_lm_head_flops(seqlen, hidden_size, vocab_size): return 2 * seqlen * hidden_size * vocab_size -def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups): - head_dim = hidden_size // num_attention_heads - n_q_heads = num_attention_heads - n_kv_heads = num_query_groups - q_flops = 2 * seqlen * hidden_size * n_q_heads * head_dim - kv_flops = 2 * seqlen * hidden_size * n_kv_heads * head_dim * 2 +def calculate_qkv_projection_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups): + if args.q_lora_rank is None: + q_flops = 2 * seqlen * hidden_size * num_attention_heads * args.kv_channels + else: + q_flops = ( + 2 + * seqlen + * args.q_lora_rank + * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) + ) + if args.kv_lora_rank is None: + kv_flops = 2 * 2 * seqlen * hidden_size * num_query_groups * args.kv_channels + else: + kv_flops = ( + 2 + * seqlen + * ( + args.kv_lora_rank + * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim)) + + args.hidden_size * args.qk_pos_emb_head_dim + ) + ) + return q_flops + kv_flops -def calculate_attention_flops(seqlen, num_attention_heads, head_dim): +def calculate_attention_flops(args, seqlen, num_attention_heads): # QK^T with causal - flops = 2 * num_attention_heads * seqlen * seqlen * head_dim // 2 + if args.qk_pos_emb_head_dim: + flops = 2 * num_attention_heads * seqlen * seqlen * (args.qk_head_dim + args.qk_pos_emb_head_dim) / 2 + else: + flops = 2 * num_attention_heads * seqlen * seqlen * args.kv_channels / 2 # A*V - flops += 2 * num_attention_heads * seqlen * seqlen * head_dim + if args.v_head_dim: + flops += num_attention_heads * seqlen * seqlen * args.v_head_dim + else: + flops += num_attention_heads * seqlen * seqlen * args.kv_channels return flops @@ -31,12 +54,10 @@ def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size): return 2 * seqlen * hidden_size * ffn_hidden_size * 3 -def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size, head_dim): - if head_dim is None: - head_dim = hidden_size // num_attention_heads +def calculate_layer_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size): return ( - calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups) - + calculate_attention_flops(seqlen, num_attention_heads, head_dim) + calculate_qkv_projection_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups) + + calculate_attention_flops(args, seqlen, num_attention_heads) + calculate_output_flops(seqlen, hidden_size) + calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size) ) @@ -50,7 +71,6 @@ def calculate_fwd_flops( num_attention_heads = args.num_attention_heads num_query_groups = args.num_query_groups vocab_size = args.vocab_size - kv_channels = args.kv_channels total_flops = 0 @@ -79,12 +99,12 @@ def 
calculate_fwd_flops( if num_dense_layers > 0: total_flops += ( calculate_layer_flops( + args, seqlen, hidden_size, num_attention_heads, num_query_groups, dense_ffn, - kv_channels, ) * num_dense_layers ) @@ -92,12 +112,12 @@ def calculate_fwd_flops( if num_moe_layers > 0: total_flops += ( calculate_layer_flops( + args, seqlen, hidden_size, num_attention_heads, num_query_groups, moe_ffn, - kv_channels, ) * num_moe_layers ) diff --git a/miles/utils/health_monitor.py b/miles/utils/health_monitor.py index 5757a1675..e95367e95 100644 --- a/miles/utils/health_monitor.py +++ b/miles/utils/health_monitor.py @@ -8,52 +8,130 @@ class RolloutHealthMonitor: + """Health monitor for rollout engines. + + The monitor runs continuously once started, but can be paused/resumed + based on whether the engines are offloaded (cannot health check when offloaded). + + Lifecycle: + - start(): Start the monitor thread (called once during initialization) + - pause(): Pause health checking (called when offloading engines) + - resume(): Resume health checking (called when onloading engines) + - stop(): Stop the monitor thread completely (called during dispose) + """ + def __init__(self, rollout_manager, args): # TODO may remove this dependency after refactoring self._rollout_manager = rollout_manager self._thread = None self._stop_event = None + self._pause_event = None # When set, health checking is paused self._check_interval = args.rollout_health_check_interval self._check_timeout = args.rollout_health_check_timeout self._check_first_wait = args.rollout_health_check_first_wait + self._need_first_wait = True # Need to wait after each resume + self._is_checking_enabled = False # Track if health checking should be active def start(self) -> bool: - if not self._rollout_manager.rollout_engines: + """Start the health monitor thread. Called once during initialization. + + Returns: + True if the monitor was started, False if there are no engines to monitor. + """ + if not self._rollout_manager.all_rollout_engines: return False - assert self._thread is None, "Health monitor thread is already running." + if self._thread is not None: + logger.warning("Health monitor thread is already running.") + return True + logger.info("Starting RolloutHealthMonitor...") self._stop_event = threading.Event() + self._pause_event = threading.Event() + self._pause_event.set() # Start in paused state until resume() is called self._thread = threading.Thread( target=self._health_monitor_loop, name="RolloutHealthMonitor", daemon=True, ) self._thread.start() + logger.info("RolloutHealthMonitor started (in paused state).") return True def stop(self) -> None: + """Stop the health monitor thread completely. Called during dispose.""" if not self._thread: return + logger.info("Stopping RolloutHealthMonitor...") assert self._stop_event is not None self._stop_event.set() + # Also clear pause to let the thread exit + if self._pause_event: + self._pause_event.clear() timeout = self._check_timeout + self._check_interval + 5 self._thread.join(timeout=timeout) if self._thread.is_alive(): logging.warning("Rollout health monitor thread did not terminate within %.1fs", timeout) + else: + logger.info("RolloutHealthMonitor stopped.") self._thread = None self._stop_event = None + self._pause_event = None + self._is_checking_enabled = False + + def pause(self) -> None: + """Pause health checking. 
Called when engines are offloaded.""" + if self._pause_event is None: + return + logger.info("Pausing health monitor...") + self._pause_event.set() + self._is_checking_enabled = False + + def resume(self) -> None: + """Resume health checking. Called when engines are onloaded.""" + if self._pause_event is None: + return + logger.info("Resuming health monitor...") + self._need_first_wait = True # Need to wait after each resume + self._pause_event.clear() + self._is_checking_enabled = True + + def is_checking_enabled(self) -> bool: + """Return whether health checking is currently enabled (not paused).""" + return self._is_checking_enabled def _health_monitor_loop(self) -> None: assert self._stop_event is not None - # TODO: need to be waiting for the large moe to be ready. this is hacky. - if self._stop_event.wait(self._check_first_wait): - return + assert self._pause_event is not None + while not self._stop_event.is_set(): - self._run_health_checks() + # Wait while paused + while self._pause_event.is_set() and not self._stop_event.is_set(): + self._stop_event.wait(timeout=0.5) + + if self._stop_event.is_set(): + break + + # Do first wait after each resume (for large MoE models to be ready) + if self._need_first_wait: + logger.info(f"Health monitor doing first wait after resume: {self._check_first_wait}s") + if self._stop_event.wait(self._check_first_wait): + logger.info("Health monitor stopped during first wait.") + break + if self._pause_event.is_set(): + # Got paused during first wait, skip this round and wait again next resume + logger.info("Health monitor paused during first wait, will wait again next resume.") + continue + self._need_first_wait = False + + # Run health checks + if not self._pause_event.is_set() and not self._stop_event.is_set(): + self._run_health_checks() + + # Wait for next check interval if self._stop_event.wait(self._check_interval): break @@ -61,29 +139,40 @@ def _run_health_checks(self) -> None: for rollout_engine_id, engine in enumerate(self._rollout_manager.rollout_engines): if self._stop_event is not None and self._stop_event.is_set(): break + if self._pause_event is not None and self._pause_event.is_set(): + break self._check_engine_health(rollout_engine_id, engine) def _check_engine_health(self, rollout_engine_id, engine) -> None: if engine is None: + logger.info(f"Skipping health check for engine {rollout_engine_id} (None)") return try: ray.get(engine.health_generate.remote(timeout=self._check_timeout)) except Exception as e: - logger.info( - f"Health check timed out for rollout engine {rollout_engine_id} (ray timeout). Killing actor. (original exception: {e})" + logger.error( + f"Health check failed for rollout engine {rollout_engine_id} (ray timeout or error). Killing actor. 
Exception: {e}" ) self._kill_engine(rollout_engine_id=rollout_engine_id) + else: + logger.debug(f"Health check passed for rollout engine {rollout_engine_id}") def _kill_engine(self, rollout_engine_id: int): + logger.info(f"Killing engine group {rollout_engine_id}...") for i in range( rollout_engine_id * self._rollout_manager.nodes_per_engine, (rollout_engine_id + 1) * self._rollout_manager.nodes_per_engine, ): engine = self._rollout_manager.all_rollout_engines[i] - try: - ray.get(engine.shutdown.remote()) - ray.kill(engine) - except Exception as e: - logger.info(f"Fail to kill engine and skip (e: {e})") + if engine: + logger.info(f"Shutting down and killing engine at index {i}") + try: + ray.get(engine.shutdown.remote()) + ray.kill(engine) + logger.info(f"Successfully killed engine at index {i}") + except Exception as e: + logger.warning(f"Fail to kill engine at index {i} (e: {e})") + else: + logger.info(f"Engine at index {i} is already None") self._rollout_manager.all_rollout_engines[i] = None diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 04b7a677e..0abdbbf59 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -45,34 +45,64 @@ def get_host_info(): if env_overwrite_local_ip := os.getenv(MILES_HOST_IP_ENV, None): return hostname, env_overwrite_local_ip - # try DNS - try: - return hostname, socket.gethostbyname(hostname) - except socket.gaierror: - pass + def _is_loopback(ip): + return ip.startswith("127.") or ip == "::1" - # try IPv4 - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_sock: - udp_sock.connect(("8.8.8.8", 80)) # Google DNS - return hostname, udp_sock.getsockname()[0] - except OSError: - pass + def _resolve_ip(family, test_target_ip): + """ + Attempt to get the local LAN IP for the specific family (IPv4/IPv6). + Strategy: UDP Probe (Preferred) -> Hostname Resolution (Fallback) -> None + """ - # try IPv6 - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s6: - s6.connect(("2001:4860:4860::8888", 80)) - return hostname, s6.getsockname()[0] - except OSError: - pass + # Strategy 1: UDP Connect Probe (Most accurate, relies on routing table) + # Useful when the machine has a default gateway or internet access. + try: + with socket.socket(family, socket.SOCK_DGRAM) as s: + # The IP doesn't need to be reachable, but the routing table must exist. + s.connect((test_target_ip, 80)) + ip = s.getsockname()[0] + if not _is_loopback(ip): + return ip + except Exception: + pass # Route unreachable or network error, move to next strategy. + + # Strategy 2: Hostname Resolution (Fallback for offline clusters) + # Useful for offline environments where UDP connect fails but /etc/hosts is configured. + try: + # getaddrinfo allows specifying the family (AF_INET or AF_INET6) + # Result format: [(family, type, proto, canonname, sockaddr), ...] + infos = socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM) + + for info in infos: + ip = info[4][0] # The first element of sockaddr is the IP + # Must filter out loopback addresses to avoid "127.0.0.1" issues + if not _is_loopback(ip): + return ip + except Exception: + pass - # hostname -I - try: - local_ip = os.popen("hostname -I | awk '{print $1}'").read().strip() - return hostname, local_ip or "::1" - except Exception: - return hostname, "::1" + return None + + prefer_ipv6 = os.getenv("MILES_PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on") + local_ip = None + final_fallback = "127.0.0.1" + + if prefer_ipv6: + # [Strict Mode] IPv6 Only + # 1. 
Try UDP V6 Probe + # 2. Try Hostname Resolution (V6) + # If failed, fallback to V6 loopback. Never mix with V4. + local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888") + final_fallback = "::1" + else: + # [Strict Mode] IPv4 Only (Default) + # 1. Try UDP V4 Probe + # 2. Try Hostname Resolution (V4) + # If failed, fallback to V4 loopback. Never mix with V6. + local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8") + final_fallback = "127.0.0.1" + + return hostname, local_ip or final_fallback def _wrap_ipv6(host): @@ -132,11 +162,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60): +async def _post(client, url, payload, max_retries=60, action="post"): retry_count = 0 while retry_count < max_retries: try: - response = await client.post(url, json=payload or {}) + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -210,8 +244,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60): - return await _post(self._client, url, payload, max_retries) + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) # Create actors per node created = [] @@ -235,7 +269,8 @@ async def do_post(self, url, payload, max_retries=60): _post_actors = created -async def post(url, payload, max_retries=60): +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) +async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. 
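+    # Illustrative usage of the new `action` parameter (endpoint/URL below is hypothetical):
+    #   await post("http://engine-host:30000/some_endpoint", None, action="delete")
+    # For "delete"/"get" the payload must be empty; all other actions send `payload` as the JSON body.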
if _distributed_post_enabled and _post_actors: try: @@ -244,15 +279,16 @@ async def post(url, payload, max_retries=60): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries) + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries) + return await _post(_http_client, url, payload, max_retries, action=action) +# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/miles/utils/mask_utils.py b/miles/utils/mask_utils.py index 36fc75aac..0ddb3a141 100644 --- a/miles/utils/mask_utils.py +++ b/miles/utils/mask_utils.py @@ -2,6 +2,7 @@ def get_response_lengths(loss_masks: list[list[int]]) -> list[int]: + # return the lengths starting from the first occurrence of 1 to the end of each loss mask return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks] @@ -44,12 +45,17 @@ def get_system_message_length(self) -> tuple[int, int]: system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids)) return system_message_length, gen_token_length - def gen_multi_turn_loss_mask_qwen(self, messages: list[dict]) -> tuple[list[int], list[int]]: + def gen_multi_turn_loss_mask_qwen( + self, messages: list[dict], tools: list[dict] = None + ) -> tuple[list[int], list[int]]: all_loss_masks = [] all_token_ids = [] for i, message in enumerate(messages): - message_ids = self.tokenizer.apply_chat_template([message], tokenize=True) + if i == 0: + message_ids = self.tokenizer.apply_chat_template([message], tokenize=True, tools=tools) + else: + message_ids = self.tokenizer.apply_chat_template([message], tokenize=True) if message["role"] != "system" and i > 0: message_ids = message_ids[self.system_message_length :] @@ -67,7 +73,9 @@ def gen_multi_turn_loss_mask_qwen(self, messages: list[dict]) -> tuple[list[int] return all_token_ids, all_loss_masks - def gen_multi_turn_loss_mask_qwen3(self, messages: list[dict]) -> tuple[list[int], list[int]]: + def gen_multi_turn_loss_mask_qwen3( + self, messages: list[dict], tools: list[dict] = None + ) -> tuple[list[int], list[int]]: all_loss_masks = [] all_token_ids = [] @@ -75,8 +83,14 @@ def gen_multi_turn_loss_mask_qwen3(self, messages: list[dict]) -> tuple[list[int prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True) for i, message in enumerate(messages): - prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True) - message_ids = prefixed_message_ids[len(prefix_token_ids) :] + if i == 0: + tailed_message_ids = self.tokenizer.apply_chat_template( + [message, prefix_message], tokenize=True, tools=tools + ) + message_ids = tailed_message_ids[: -len(prefix_token_ids)] + else: + prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True) + message_ids = prefixed_message_ids[len(prefix_token_ids) :] if message["role"] != "system" and i > 0: message_ids = message_ids[self.system_message_length :] @@ -94,8 +108,12 @@ def gen_multi_turn_loss_mask_qwen3(self, messages: list[dict]) -> tuple[list[int return all_token_ids, all_loss_masks - def 
gen_multi_turn_loss_mask_distill_qwen(self, messages: list[dict]) -> tuple[list[int], list[int]]: - prompt = self.tokenizer.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True) + def gen_multi_turn_loss_mask_distill_qwen( + self, messages: list[dict], tools: list[dict] = None + ) -> tuple[list[int], list[int]]: + prompt = self.tokenizer.apply_chat_template( + messages[:1], tokenize=False, add_generation_prompt=True, tools=tools + ) response = messages[-1]["content"] prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] response_tokens = self.tokenizer(response, add_special_tokens=False)["input_ids"] @@ -108,19 +126,46 @@ def gen_multi_turn_loss_mask_distill_qwen(self, messages: list[dict]) -> tuple[l loss_mask = [0] * len(token_ids) return token_ids, loss_mask - def get_loss_mask(self, messages: list[dict]) -> list[int]: + def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple[list[int], list[int]]: if self.tokenizer_type == "qwen": if "<๏ฝœAssistant๏ฝœ>" in self.tokenizer.get_added_vocab(): - return self.gen_multi_turn_loss_mask_distill_qwen(messages) + return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools) - return self.gen_multi_turn_loss_mask_qwen(messages) + return self.gen_multi_turn_loss_mask_qwen(messages, tools) elif self.tokenizer_type == "qwen3": - return self.gen_multi_turn_loss_mask_qwen3(messages) + return self.gen_multi_turn_loss_mask_qwen3(messages, tools) elif self.tokenizer_type == "distill_qwen": - return self.gen_multi_turn_loss_mask_distill_qwen(messages) + return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools) else: raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}") + def get_loss_mask_with_multimodal_alignment( + self, messages: list[dict], input_ids: list[int], tools: list[dict] = None + ) -> tuple[list[int], list[int]]: + text = [] + for msg in messages: + if isinstance(msg.get("content"), list): + text_parts = [] + for item in msg["content"]: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif isinstance(item, str): + text_parts.append(item) + text.append({"role": msg["role"], "content": " ".join(text_parts)}) + else: + text.append(msg) + + _, loss_mask_text = self.get_loss_mask(text, tools=tools) + + diff = len(input_ids) - len(loss_mask_text) + assert diff >= 0, ( + f"input_ids (length={len(input_ids)}) is shorter than text loss_mask (length={len(loss_mask_text)}) " + f"Please check if processor and tokenizer tokenization are consistent." 
+ ) + loss_mask = [0] * diff + loss_mask_text + + return input_ids, loss_mask + def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -> list[str]: selected_texts = [] current_tokens = [] diff --git a/miles/utils/megatron_bridge_utils.py b/miles/utils/megatron_bridge_utils.py index d8bc8060b..9e5f065cd 100644 --- a/miles/utils/megatron_bridge_utils.py +++ b/miles/utils/megatron_bridge_utils.py @@ -10,10 +10,13 @@ def patch_megatron_model(model): unwrapped_model = unwrap_model(model)[0] model_config = unwrapped_model.config - assert not hasattr(model_config, "share_embeddings_and_output_weights") - model_config.share_embeddings_and_output_weights = unwrapped_model.share_embeddings_and_output_weights + attribute_was_added = False + if not hasattr(model_config, "share_embeddings_and_output_weights"): + model_config.share_embeddings_and_output_weights = unwrapped_model.share_embeddings_and_output_weights + attribute_was_added = True try: yield finally: - delattr(model_config, "share_embeddings_and_output_weights") + if attribute_was_added: + delattr(model_config, "share_embeddings_and_output_weights") diff --git a/miles/utils/metric_utils.py b/miles/utils/metric_utils.py index fe8d6df50..66292c79e 100644 --- a/miles/utils/metric_utils.py +++ b/miles/utils/metric_utils.py @@ -58,6 +58,8 @@ def compute_statistics(values: list[float]) -> dict[str, float]: return { "mean": np.mean(values).item(), "median": np.median(values).item(), + "max": np.max(values).item(), + "min": np.min(values).item(), } @@ -105,7 +107,7 @@ def compression_ratio( return ratio, savings_pct -def has_repetition(text: str = None): +def has_repetition(text: str): if len(text) > 10000 and compression_ratio(text[-10000:])[0] > 10: return True else: diff --git a/miles/utils/misc.py b/miles/utils/misc.py index 2fe825812..bae72ec0d 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,17 +1,55 @@ +import asyncio import importlib import subprocess +from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available +# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions +class FunctionRegistry: + def __init__(self): + self._registry: dict[str, object] = {} + + @contextmanager + def temporary(self, name: str, fn: object): + self._register(name, fn) + try: + yield + finally: + self._unregister(name) + + def get(self, name: str) -> object | None: + return self._registry.get(name) + + def _register(self, name: str, fn: object) -> None: + assert name not in self._registry + self._registry[name] = fn + + def _unregister(self, name: str) -> None: + assert name in self._registry + self._registry.pop(name) + + +function_registry = FunctionRegistry() + + +# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ - Load a function from a module. + Load a function from registry or module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. 
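+
+    Example (illustrative; uses the stdlib ``json`` module purely for demonstration):
+        >>> load_function("json.dumps")({"a": 1})
+        '{"a": 1}'
+    Names registered through ``function_registry.temporary(...)`` take priority over module lookup.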
""" + if path is None: + return None + + registered = function_registry.get(path) + if registered is not None: + return registered + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -30,6 +68,10 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] + @staticmethod + def clear_all_instances(): + SingletonMeta._instances.clear() + def exec_command(cmd: str, capture_output: bool = False) -> str | None: print(f"EXEC: {cmd}", flush=True) @@ -71,6 +113,7 @@ def should_run_periodic_action( rollout_id: int, interval: int | None, num_rollout_per_epoch: int | None = None, + num_rollout: int | None = None, ) -> bool: """ Return True when a periodic action (eval/save/checkpoint) should run. @@ -83,5 +126,13 @@ def should_run_periodic_action( if interval is None: return False + if num_rollout is not None and rollout_id == num_rollout - 1: + return True + step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) + + +async def as_completed_async(tasks): + for coro in asyncio.as_completed(tasks): + yield await coro diff --git a/miles/utils/ppo_utils.py b/miles/utils/ppo_utils.py index c301ef624..34904477a 100644 --- a/miles/utils/ppo_utils.py +++ b/miles/utils/ppo_utils.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from miles.backends.training_utils.parallel import ParallelState @torch.compile(dynamic=True) @@ -149,6 +150,7 @@ def compute_policy_loss( def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None): + # TODO: when megatron is not installed, fall back to naive implementation from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy # convert to [seq_len, batch_size, vocab_size] as expected by fused_vocab_parallel_cross_entropy @@ -215,6 +217,7 @@ def get_reinforce_plus_plus_returns( total_lengths: list[int], kl_coef: float, gamma: float, + parallel_state: ParallelState, ) -> list[torch.Tensor]: """ Calculates discounted returns for REINFORCE++ (https://arxiv.org/pdf/2501.03262) @@ -243,9 +246,9 @@ def get_reinforce_plus_plus_returns( if cp_size > 1: # Step 1,2:Gather all chunks and token_offsets from all ranks and reconstruct the full response tensor by splitting and placing each part - from miles.backends.megatron_utils.cp_utils import all_gather_with_cp + from miles.backends.training_utils.cp_utils import all_gather_with_cp - full_kl_response = all_gather_with_cp(local_kl_chunk, total_len, response_len) + full_kl_response = all_gather_with_cp(local_kl_chunk, total_len, response_len, parallel_state) else: full_kl_response = local_kl_chunk @@ -267,7 +270,7 @@ def get_reinforce_plus_plus_returns( if cp_size > 1: from miles.backends.megatron_utils.cp_utils import slice_log_prob_with_cp - local_returns_chunk = slice_log_prob_with_cp(returns_for_seq, total_len, response_len) + local_returns_chunk = slice_log_prob_with_cp(returns_for_seq, total_len, response_len, parallel_state) else: local_returns_chunk = returns_for_seq @@ -313,6 +316,7 @@ def get_advantages_and_returns( rewards: torch.Tensor, gamma: float, lambd: float, + parallel_state: ParallelState, ) -> tuple[torch.Tensor, torch.Tensor]: """Function that computes advantages and returns from rewards and values. 
Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 @@ -338,10 +342,10 @@ def get_advantages_and_returns( cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: - from miles.backends.megatron_utils.cp_utils import all_gather_with_cp + from miles.backends.training_utils.cp_utils import all_gather_with_cp - full_rewards = all_gather_with_cp(rewards, total_len, response_len) - full_values = all_gather_with_cp(values, total_len, response_len) + full_rewards = all_gather_with_cp(rewards, total_len, response_len, parallel_state) + full_values = all_gather_with_cp(values, total_len, response_len, parallel_state) else: full_rewards = rewards full_values = values @@ -360,8 +364,8 @@ def get_advantages_and_returns( if cp_size > 1: from miles.backends.megatron_utils.cp_utils import slice_log_prob_with_cp - advantages = slice_log_prob_with_cp(full_advantages, total_len, response_len) - returns = slice_log_prob_with_cp(full_returns, total_len, response_len) + advantages = slice_log_prob_with_cp(full_advantages, total_len, response_len, parallel_state) + returns = slice_log_prob_with_cp(full_returns, total_len, response_len, parallel_state) else: advantages = full_advantages returns = full_returns @@ -376,6 +380,7 @@ def get_advantages_and_returns_batch( rewards_list, gamma, lambd, + parallel_state: ParallelState, chunked: bool = True, ): """ @@ -402,7 +407,7 @@ def get_advantages_and_returns_batch( dtype = values_list[0].dtype if cp_size > 1: - from miles.backends.megatron_utils.cp_utils import all_gather_with_cp + from miles.backends.training_utils.cp_utils import all_gather_with_cp full_values_list = [] full_rewards_list = [] @@ -410,8 +415,8 @@ def get_advantages_and_returns_batch( for total_len, resp_len, v, r in zip( total_lengths, response_lengths, values_list, rewards_list, strict=False ): - full_v = all_gather_with_cp(v, total_len, resp_len) - full_r = all_gather_with_cp(r, total_len, resp_len) + full_v = all_gather_with_cp(v, total_len, resp_len, parallel_state) + full_r = all_gather_with_cp(r, total_len, resp_len, parallel_state) full_values_list.append(full_v) full_rewards_list.append(full_r) @@ -450,7 +455,7 @@ def get_advantages_and_returns_batch( returns_list = [] if cp_size > 1: - from miles.backends.megatron_utils.cp_utils import slice_log_prob_with_cp + from miles.backends.training_utils.cp_utils import slice_log_prob_with_cp for total_len, resp_len, adv_row, ret_row in zip( total_lengths, @@ -462,8 +467,8 @@ def get_advantages_and_returns_batch( adv_full = adv_row # shape = [resp_len_i padded to max_len] ret_full = ret_row - adv_sliced = slice_log_prob_with_cp(adv_full[:resp_len], total_len, resp_len) - ret_sliced = slice_log_prob_with_cp(ret_full[:resp_len], total_len, resp_len) + adv_sliced = slice_log_prob_with_cp(adv_full[:resp_len], total_len, resp_len, parallel_state) + ret_sliced = slice_log_prob_with_cp(ret_full[:resp_len], total_len, resp_len, parallel_state) advantages_list.append(adv_sliced) returns_list.append(ret_sliced) @@ -644,21 +649,35 @@ def chunked_gae( return advantages, returns -def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False): +def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1): logits = logits.contiguous() # TODO: not sure why we need to clone the logits here. # Without the clone, the backward will trigger inplace edit error. # It seems that the function with tp will modify the logits inplace. 
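+    # When `chunk_size` > 0, tokens and logits are split along dim 0 and processed chunk by chunk,
+    # with the per-chunk log-probs (and entropies) concatenated afterwards. Illustrative sizing:
+    # 4096 positions with chunk_size=1024 give (4096 - 1) // 1024 + 1 = 4 chunks of 1024 tokens each.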
+ entropy = None if logits.size(0) != 0: - log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + if chunk_size > 0: + num_chunks = (logits.size(0) - 1) // chunk_size + 1 + tokens_chunks = tokens.chunk(num_chunks, dim=0) + logits_chunks = logits.chunk(num_chunks, dim=0) + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) + log_probs.append(log_prob) + log_prob = torch.cat(log_probs, dim=0) + if with_entropy: + entropys = [] + for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group) + entropys.append(entropy) + entropy = torch.cat(entropys, dim=0) + else: + log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + if with_entropy: + entropy = compute_entropy_from_logits(logits.clone(), tp_group) else: log_prob = logits.new_zeros((0,)) - - if with_entropy: - if logits.size(0) != 0: - entropy = compute_entropy_from_logits(logits.clone(), tp_group) - else: + if with_entropy: entropy = logits.new_zeros((0,)) - else: - entropy = None + return log_prob, entropy diff --git a/miles/utils/processing_utils.py b/miles/utils/processing_utils.py index 60cd8f255..f36f93c1b 100644 --- a/miles/utils/processing_utils.py +++ b/miles/utils/processing_utils.py @@ -6,6 +6,11 @@ logger = logging.getLogger(__name__) +# Default image patch size for vision-language models +# Note: Qwen3-VL uses 16, Qwen2.5-VL uses 14 +# Reference: https://github.com/QwenLM/Qwen3-VL/blob/main/qwen-vl-utils/README.md +DEFAULT_PATCH_SIZE = 14 + def load_tokenizer(name_or_path: str, **kwargs): return AutoTokenizer.from_pretrained(name_or_path, **kwargs) @@ -25,50 +30,22 @@ def load_processor(name_or_path: str, **kwargs): return proc -def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None): - """Prepare all inputs for model inference. 
- - Returns: - tuple: (input_ids, extra_info) - - input_ids: Token IDs for the prompt - - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict) - """ - tools = metadata.get("tools") if metadata else None - text_prompt = tokenizer.apply_chat_template( - prompt, - tools=tools, - tokenize=False, - add_generation_prompt=True, - **(apply_chat_template_kwargs or {}), - ) +def process_vision_info(prompt, processor): + # temporary solution, will write image utils for miles later + from qwen_vl_utils import process_vision_info - if not processor: - input_ids = tokenizer.encode(text_prompt, add_special_tokens=False) - return input_ids, {} + if hasattr(processor.image_processor, "patch_size"): + image_patch_size = processor.image_processor.patch_size else: - # temporary solution, will write image utils for miles later - from qwen_vl_utils import process_vision_info - - images, videos = process_vision_info(prompt) - - # Get input IDs with full prompt (text + multimodal) - processor_output = processor(text=text_prompt, images=images, videos=videos) - input_ids = processor_output["input_ids"][0] - - # Extract multimodal tokens (exclude text-related tokens) - multimodal_inputs = {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]} - - extra_info = { - "images": images, - "videos": videos, - "multimodal_inputs": multimodal_inputs, - } - - return input_ids, extra_info + logger.info(f"Using default patch size: {DEFAULT_PATCH_SIZE}") + image_patch_size = DEFAULT_PATCH_SIZE + images, videos = process_vision_info(prompt, image_patch_size=image_patch_size) + multimodal_inputs = {"images": images, "videos": videos} + return multimodal_inputs def encode_image_for_rollout_engine(image) -> str: - """Load an image from path, ensure RGB, encode as JPEG base64 string.""" + """Load an image from path, ensure RGB, encode as PNG base64 string.""" buffer = io.BytesIO() if image.mode != "RGB": image = image.convert("RGB") diff --git a/miles/utils/seqlen_balancing.py b/miles/utils/seqlen_balancing.py index 5bee97c6e..a5dd71f94 100644 --- a/miles/utils/seqlen_balancing.py +++ b/miles/utils/seqlen_balancing.py @@ -165,11 +165,11 @@ def _check_and_sort_partitions(partitions): assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" seen_idx = set() sorted_partitions = [None] * k_partitions - for _i, partition in enumerate(partitions): - assert len(partition) > 0, f"the {_i}-th partition is empty" + for i, partition in enumerate(partitions): + assert len(partition) > 0, f"the {i}-th partition is empty" for idx in partition: seen_idx.add(idx) - sorted_partitions[_i] = sorted(partition) + sorted_partitions[i] = sorted(partition) assert seen_idx == set(range(len(seqlen_list))) return sorted_partitions diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py new file mode 100644 index 000000000..2c0dddfe5 --- /dev/null +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -0,0 +1,248 @@ +import asyncio +import re +import time +import uuid +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import asdict, dataclass + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from 
sglang.srt.function_call.function_call_parser import FunctionCallParser +from transformers import AutoTokenizer + +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class ProcessResultMetaInfo: + weight_version: str | None = None + routed_experts: str | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None + spec_verify_ct: int | None = None + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass(frozen=True) +class ProcessResult: + text: str + finish_reason: str = "stop" + cached_tokens: int = 0 + meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() + + +ProcessFn = Callable[[str], ProcessResult] + + +class MockSGLangServer: + def __init__( + self, + model_name: str, + process_fn: ProcessFn, + host: str, + port: int, + latency: float = 0.0, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.process_fn = process_fn + self.host = host + self.port = port or find_available_port(30000) + self.latency = latency + + self.app = FastAPI() + self._server: UvicornThreadServer | None = None + + self.request_log: list[dict] = [] + self._concurrency = Counter() + + self._setup_routes() + + @property + def max_concurrent(self) -> int: + return self._concurrency.max_value + + def reset_stats(self): + self.request_log.clear() + self._concurrency.reset() + + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() + + def stop(self): + if self._server is not None: + self._server.stop() + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + return await self._handle_generate_like_request(request, self._compute_generate_response) + + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await self._handle_generate_like_request(request, self._compute_chat_completions_response) + + @self.app.get("/health") + async def health(): + return JSONResponse(content={"status": "ok"}) + + @self.app.post("/abort_request") + async def abort_request(_request: Request): + return JSONResponse(content={"status": "ok"}) + + async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): + payload = await request.json() + self.request_log.append(payload) + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) + response = compute_fn(payload) + return JSONResponse(content=response) + + def _compute_generate_response(self, payload: dict) -> dict: + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + meta_info = { + "finish_reason": 
finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + + return {"text": process_result.text, "meta_info": meta_info} + + def _compute_chat_completions_response(self, payload: dict) -> dict: + messages = payload.get("messages", []) + tools = payload.get("tools") + + prompt_str = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + logprobs_content = [ + {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} + for i, tid in enumerate(output_ids) + ] + + finish_reason = process_result.finish_reason + tool_calls = None + if tools and finish_reason == "stop": + parser = FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tools), + tool_call_parser="qwen25", + ) + message_content, parsed_calls = parser.parse_non_stream(process_result.text) + if parsed_calls: + finish_reason = "tool_calls" + tool_calls = [ + { + "id": f"call{i:05d}", + "type": "function", + "function": {"name": call.name, "arguments": call.parameters or "{}"}, + } + for i, call in enumerate(parsed_calls) + ] + else: + message_content = process_result.text + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls, + }, + "logprobs": {"content": logprobs_content}, + "finish_reason": finish_reason, + } + ], + } + + +class Counter: + def __init__(self): + self._current = 0 + self._max = 0 + + @property + def max_value(self) -> int: + return self._max + + def reset(self): + self._current = 0 + self._max = 0 + + @contextmanager + def track(self): + self._current += 1 + self._max = max(self._max, self._current) + try: + yield + finally: + self._current -= 1 + + +def default_process_fn(prompt: str) -> ProcessResult: + match = re.search(r"What is 1\+(\d+)\?", prompt) + if match: + num = int(match.group(1)) + ans = 1 + num + return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") + return ProcessResult(text="I don't understand.", finish_reason="stop") + + +@contextmanager +def with_mock_server( + model_name: str = "Qwen/Qwen3-0.6B", + process_fn: ProcessFn = default_process_fn, + host: str = "127.0.0.1", + port: int | None = None, + latency: float = 0.0, +): + server = MockSGLangServer( + model_name=model_name, + process_fn=process_fn, + host=host, + port=port, + latency=latency, + ) + try: + server.start() + yield server + finally: + server.stop() diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py new file mode 100644 index 000000000..6b99e3673 --- /dev/null +++ b/miles/utils/test_utils/mock_tools.py @@ -0,0 +1,268 @@ +import json + +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_year", + "description": "Get current year", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_temperature", + "description": "Get temperature for a location", 
+ "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + }, +] + + +def _get_year(params: dict) -> str: + assert len(params) == 0 + return json.dumps({"year": 2026}) + + +def _get_temperature(params: dict) -> str: + temps = {"Mars": -60, "Earth": 15} + location = params.get("location") + assert location in temps, f"Unknown location: {location}" + return json.dumps({"temperature": temps[location]}) + + +TOOL_EXECUTORS = { + "get_year": _get_year, + "get_temperature": _get_temperature, +} + + +async def execute_tool_call(name: str, params: dict) -> str: + return TOOL_EXECUTORS[name](params) + + +_SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" +) + + +_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + + +class TwoTurnStub: + """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" + + USER_QUESTION = "What is 42 + year + temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." 
+ FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, + TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") + + +class ThreeTurnStub: + """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" + + USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and Mars temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + SECOND_RESPONSE = ( + "Now let me get Earth temperature.\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"temperature": 15}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." 
+ SECOND_TOOL_CALLS_OPENAI_FORMAT = [ + { + "id": "call00000", + "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, + ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, + ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py new file mode 100644 index 000000000..904343c98 --- /dev/null +++ b/miles/utils/test_utils/uvicorn_thread_server.py @@ -0,0 +1,49 @@ +import asyncio +import socket +import threading +import time + +import uvicorn + + +class UvicornThreadServer: + def __init__(self, app, host: str, port: int): + self._app = app + self.host = host + self.port = port + self._server: uvicorn.Server | None = None + self._thread: threading.Thread | None = None + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") + self._server = uvicorn.Server(config) + + def run() -> None: + asyncio.run(self._server.serve()) + + self._thread = threading.Thread(target=run, daemon=True) + self._thread.start() + self._wait_for_port_open() + + def stop(self) -> None: + if self._server is not None: + self._server.should_exit = True + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + def _wait_for_port_open(self) -> None: + for _ in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex((self.host, self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/miles/utils/types.py b/miles/utils/types.py index 9003050a2..5200d625e 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -14,7 +14,8 @@ class Sample: # prompt prompt: str | list[dict[str, str]] = "" tokens: list[int] = field(default_factory=list) - multimodal_inputs: dict[str, Any] = None + multimodal_inputs: dict[str, Any] = None # raw multimodal data, e.g. images, videos, etc. 
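+    # e.g. (illustrative) multimodal_inputs may hold {"images": [...], "videos": [...]} straight from the
+    # dataset, while multimodal_train_inputs below holds processor outputs such as pixel_values.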
+ multimodal_train_inputs: dict[str, Any] = None # processed multimodal data, e.g. pixel_values, etc. # response response: str = "" response_length: int = 0 @@ -31,6 +32,10 @@ class Status(Enum): COMPLETED = "completed" TRUNCATED = "truncated" ABORTED = "aborted" + # Indicates a recoverable or non-critical failure during generation (e.g., tool call failure, + # external API error, parsing error). Unlike ABORTED, FAILED samples may still contain partial + # valid output and can be retried or handled gracefully. + FAILED = "failed" status: Status = Status.PENDING @@ -38,31 +43,35 @@ class Status(Enum): # metadata used during training, e.g., what loss to use for this sample. train_metadata: dict | None = None + non_generation_time: float = 0.0 # time spent in non-generation steps + + @dataclass class SpecInfo: spec_accept_token_num: int = 0 spec_draft_token_num: int = 0 spec_verify_ct: int = 0 - spec_accept_rate: float = 0.0 - spec_accept_length: float = 0.0 - - def add(self, meta_info: dict, response_length: int): - self.spec_accept_token_num += meta_info["spec_accept_token_num"] - self.spec_draft_token_num += meta_info["spec_draft_token_num"] - self.spec_verify_ct += meta_info["spec_verify_ct"] - if self.spec_draft_token_num > 0: - # Notice: this does not iclude the bonus token generated by verify step. - self.spec_accept_rate = self.spec_accept_token_num / self.spec_draft_token_num - # self.spec_accept_rate = meta_info["spec_accept_rate"] # - if self.spec_verify_ct > 0: - self.spec_accept_length = response_length / self.spec_verify_ct + completion_token_num: int = 0 + + @property + def spec_accept_rate(self) -> float: + return self.spec_accept_token_num / self.spec_draft_token_num if self.spec_draft_token_num > 0 else 0.0 + + @property + def spec_accept_length(self) -> float: + return self.completion_token_num / self.spec_verify_ct if self.spec_verify_ct > 0 else 0.0 + + def add(self, meta_info: dict): + self.spec_accept_token_num += meta_info.get("spec_accept_token_num", 0) + self.spec_draft_token_num += meta_info.get("spec_draft_token_num", 0) + self.spec_verify_ct += meta_info.get("spec_verify_ct", 0) + self.completion_token_num += meta_info.get("completion_tokens", 0) def to_dict(self): return { "spec_accept_token_num": self.spec_accept_token_num, "spec_draft_token_num": self.spec_draft_token_num, "spec_verify_ct": self.spec_verify_ct, - "spec_accept_rate": self.spec_accept_rate, - "spec_accept_length": self.spec_accept_length, + "completion_token_num": self.completion_token_num, } @staticmethod @@ -71,23 +80,63 @@ def from_dict(data: dict): info.spec_accept_token_num = data.get("spec_accept_token_num", 0) info.spec_draft_token_num = data.get("spec_draft_token_num", 0) info.spec_verify_ct = data.get("spec_verify_ct", 0) - info.spec_accept_rate = data.get("spec_accept_rate", 0.0) - info.spec_accept_length = data.get("spec_accept_length", 0.0) + info.completion_token_num = data.get("completion_token_num", 0) return info spec_info: SpecInfo = field(default_factory=SpecInfo) + @dataclass + class PrefixCacheInfo: + cached_tokens: int = 0 + total_prompt_tokens: int = 0 + + @property + def prefix_cache_hit_rate(self) -> float: + return self.cached_tokens / self.total_prompt_tokens if self.total_prompt_tokens > 0 else 0.0 + + def add(self, meta_info: dict): + self.cached_tokens += meta_info.get("cached_tokens", 0) + # new_tokens = input_tokens - cached_tokens + self.total_prompt_tokens += meta_info.get("prompt_tokens", 0) + + def to_dict(self): + return { + "cached_tokens": self.cached_tokens, 
+ "total_prompt_tokens": self.total_prompt_tokens, + } + + @staticmethod + def from_dict(data: dict): + info = Sample.PrefixCacheInfo() + info.cached_tokens = data.get("cached_tokens", 0) + info.total_prompt_tokens = data.get("total_prompt_tokens", 0) + return info + + prefix_cache_info: PrefixCacheInfo = field(default_factory=PrefixCacheInfo) + def to_dict(self): value = self.__dict__.copy() value["status"] = self.status.value value["spec_info"] = self.spec_info.to_dict() + value["prefix_cache_info"] = self.prefix_cache_info.to_dict() return value @staticmethod def from_dict(data: dict): + data = dict(data) data["status"] = Sample.Status(data["status"]) data["spec_info"] = Sample.SpecInfo.from_dict(data.get("spec_info", {})) - return Sample(**data) + data["prefix_cache_info"] = Sample.PrefixCacheInfo.from_dict(data.get("prefix_cache_info", {})) + + field_names = set(Sample.__dataclass_fields__.keys()) + init_data = {k: v for k, v in data.items() if k in field_names} + sample = Sample(**init_data) + + for key, value in data.items(): + if key not in field_names: + setattr(sample, key, value) + + return sample def get_reward_value(self, args) -> float: return self.reward if not args.reward_key else self.reward[args.reward_key] @@ -96,6 +145,47 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length + def validate(self): + assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" + assert ( + len(self.tokens) >= self.response_length + ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" + if self.loss_mask is not None: + assert ( + len(self.loss_mask) == self.response_length + ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" + if self.rollout_log_probs is not None: + assert ( + len(self.rollout_log_probs) == self.response_length + ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + if self.rollout_routed_experts is not None: + actual = len(self.rollout_routed_experts) + expect = len(self.tokens) - 1 + assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" + + def update_from_meta_info(self, args, meta_info: dict): + """ + Update the sample with new information from meta_info returned by the rollout engine. + And extract + """ + if args.sglang_speculative_algorithm: + # cannot directly use spec info from sglang because of partial rollout. 
+ self.spec_info.add(meta_info=meta_info) + + # Collect prefix cache statistics + self.prefix_cache_info.add(meta_info=meta_info) + + if "weight_version" in meta_info: + self.weight_versions.append(meta_info["weight_version"]) + + match meta_info["finish_reason"]["type"]: + case "length": + self.status = Sample.Status.TRUNCATED + case "abort": + self.status = Sample.Status.ABORTED + case "stop": + self.status = Sample.Status.COMPLETED + @dataclass(frozen=True) class ParamInfo: diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py index 4b2bfbfc5..e890a8771 100644 --- a/miles/utils/wandb_utils.py +++ b/miles/utils/wandb_utils.py @@ -157,16 +157,3 @@ def _init_wandb_common(): wandb.define_metric("eval/step") wandb.define_metric("eval/*", step_metric="eval/step") wandb.define_metric("perf/*", step_metric="rollout/step") - - -def get_wandb_offline_dir(args): - """Get the directory where offline W&B data is stored.""" - if _is_offline_mode(args): - if args and hasattr(args, "wandb_dir") and args.wandb_dir: - # Use custom directory if specified - return args.wandb_dir - else: - # Default offline directory is ~/wandb/offline-run- - # This will be created automatically by wandb - return os.path.expanduser("~/wandb") - return None diff --git a/miles_plugins/mbridge/__init__.py b/miles_plugins/mbridge/__init__.py index f97c7f46e..67b824aa9 100644 --- a/miles_plugins/mbridge/__init__.py +++ b/miles_plugins/mbridge/__init__.py @@ -1,6 +1,29 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from .deepseekv32 import DeepseekV32Bridge from .glm4 import GLM4Bridge from .glm4moe import GLM4MoEBridge from .mimo import MimoBridge from .qwen3_next import Qwen3NextBridge -__all__ = ["GLM4Bridge", "GLM4MoEBridge", "Qwen3NextBridge", "MimoBridge"] +__all__ = ["DeepseekV32Bridge", "GLM4Bridge", "GLM4MoEBridge", "Qwen3NextBridge", "MimoBridge"] + +from mbridge import AutoBridge + +_original_from_config = AutoBridge.from_config + + +@classmethod +def _patched_from_config(cls, hf_config, **kwargs): + if hasattr(hf_config, "index_n_heads"): + from mbridge.core.bridge import _MODEL_REGISTRY + + return _MODEL_REGISTRY["deepseek_v32"](hf_config, **kwargs) + + return _original_from_config(hf_config, **kwargs) + + +AutoBridge.from_config = _patched_from_config diff --git a/miles_plugins/mbridge/deepseekv32.py b/miles_plugins/mbridge/deepseekv32.py new file mode 100644 index 000000000..19e417fa0 --- /dev/null +++ b/miles_plugins/mbridge/deepseekv32.py @@ -0,0 +1,57 @@ +from megatron.core.transformer.enums import AttnBackend + +from mbridge.core import register_model +from mbridge.models import DeepseekV3Bridge + + +@register_model("deepseek_v32") +class DeepseekV32Bridge(DeepseekV3Bridge): + + # Weights with parallel_mode="duplicated" that should NOT be gathered across TP + _DUPLICATED_WEIGHTS = { + "self_attention.core_attention.indexer.linear_wq_b.weight", + "self_attention.core_attention.indexer.linear_wk.weight", + "self_attention.core_attention.indexer.linear_weights_proj.weight", + } + + _ATTENTION_MAPPING = DeepseekV3Bridge._ATTENTION_MAPPING.copy() + + # Because the indexer needs the norm output, we cannot use the fused transformer engine impl and have to compute it separately. 
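+ # The fused "*.layer_norm_weight" mappings are therefore dropped below, and the q/kv norms are mapped to the standalone q_layernorm / kv_layernorm weights instead.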
+ if "self_attention.linear_q_up_proj.layer_norm_weight" in _ATTENTION_MAPPING: + del _ATTENTION_MAPPING["self_attention.linear_q_up_proj.layer_norm_weight"] + if "self_attention.linear_kv_up_proj.layer_norm_weight" in _ATTENTION_MAPPING: + del _ATTENTION_MAPPING["self_attention.linear_kv_up_proj.layer_norm_weight"] + + _ATTENTION_MAPPING.update( + { + "self_attention.q_layernorm.weight": ["model.layers.{layer_number}.self_attn.q_a_layernorm.weight"], + "self_attention.kv_layernorm.weight": ["model.layers.{layer_number}.self_attn.kv_a_layernorm.weight"], + "self_attention.core_attention.indexer.linear_wq_b.weight": [ + "model.layers.{layer_number}.self_attn.indexer.wq_b.weight" + ], + "self_attention.core_attention.indexer.linear_wk.weight": [ + "model.layers.{layer_number}.self_attn.indexer.wk.weight" + ], + "self_attention.core_attention.indexer.k_norm.weight": [ + "model.layers.{layer_number}.self_attn.indexer.k_norm.weight" + ], + "self_attention.core_attention.indexer.k_norm.bias": [ + "model.layers.{layer_number}.self_attn.indexer.k_norm.bias" + ], + "self_attention.core_attention.indexer.linear_weights_proj.weight": [ + "model.layers.{layer_number}.self_attn.indexer.weights_proj.weight" + ], + } + ) + + def _build_config(self): + config = super()._build_config() + + config.attention_backend = AttnBackend.auto + + config.experimental_attention_variant = "dsa" + config.dsa_indexer_n_heads = getattr(self.hf_config, "dsa_indexer_n_heads", 64) + config.dsa_indexer_head_dim = getattr(self.hf_config, "dsa_indexer_head_dim", 128) + config.dsa_indexer_topk = getattr(self.hf_config, "dsa_indexer_topk", 2048) + + return config diff --git a/miles_plugins/mbridge/qwen3_next.py b/miles_plugins/mbridge/qwen3_next.py index 377ba18f6..8a86dcc57 100644 --- a/miles_plugins/mbridge/qwen3_next.py +++ b/miles_plugins/mbridge/qwen3_next.py @@ -29,8 +29,8 @@ class Qwen3NextBridge(Qwen2MoEBridge): ] } | { - "self_attention.linear_qgkv.layer_norm_weight": ["model.layers.{layer_number}.input_layernorm.weight"], - "self_attention.linear_qgkv.weight": [ + "self_attention.linear_qkv.layer_norm_weight": ["model.layers.{layer_number}.input_layernorm.weight"], + "self_attention.linear_qkv.weight": [ "model.layers.{layer_number}.self_attn.q_proj.weight", "model.layers.{layer_number}.self_attn.k_proj.weight", "model.layers.{layer_number}.self_attn.v_proj.weight", @@ -41,7 +41,7 @@ class Qwen3NextBridge(Qwen2MoEBridge): def _weight_to_mcore_format( self, mcore_weights_name: str, hf_weights: list[torch.Tensor] ) -> tuple[list[str], list[torch.Tensor]]: - if "self_attention.linear_qgkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name: + if "self_attention.linear_qkv." 
in mcore_weights_name and "layer_norm" not in mcore_weights_name: # merge qkv assert len(hf_weights) == 3 num_key_value_heads = self.hf_config.num_key_value_heads @@ -96,5 +96,6 @@ def _build_config(self): moe_router_pre_softmax=False, qk_layernorm=True, # Qwen3 Next specific - use_gated_attention=True, + attention_output_gate=True, + moe_shared_expert_gate=True, ) diff --git a/miles_plugins/models/glm4.py b/miles_plugins/models/glm4.py index d3e920efd..ba42ea1a6 100644 --- a/miles_plugins/models/glm4.py +++ b/miles_plugins/models/glm4.py @@ -3,11 +3,11 @@ def get_glm_spec(args, config, vp_stage): transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.moe_use_legacy_grouped_gemm, + num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm, + qk_layernorm=args.qk_layernorm, + multi_latent_attention=args.multi_latent_attention, + moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, post_self_attn_layernorm=args.post_self_attn_layernorm, post_mlp_layernorm=args.post_mlp_layernorm, ) diff --git a/miles_plugins/models/hf_attention.py b/miles_plugins/models/hf_attention.py index 77c15074b..c353ae7b2 100644 --- a/miles_plugins/models/hf_attention.py +++ b/miles_plugins/models/hf_attention.py @@ -22,7 +22,7 @@ def __init__( config, layer_number: int, cp_comm_type: str = "p2p", - model_comm_pgs=None, + pg_collection=None, ): super().__init__(config=config) self.args = args @@ -43,6 +43,7 @@ def forward( rotary_pos_emb: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None, rotary_pos_cos: torch.Tensor | None = None, rotary_pos_sin: torch.Tensor | None = None, + rotary_pos_cos_sin: torch.Tensor | None = None, attention_bias: torch.Tensor | None = None, packed_seq_params: PackedSeqParams | None = None, sequence_len_offset: int | None = None, diff --git a/miles_plugins/models/qwen3_next.py b/miles_plugins/models/qwen3_next.py index 71e94b922..0a42e4d57 100644 --- a/miles_plugins/models/qwen3_next.py +++ b/miles_plugins/models/qwen3_next.py @@ -169,14 +169,14 @@ def __init__( config, layer_number: int, cp_comm_type: str = "p2p", - model_comm_pgs=None, + pg_collection=None, ): super().__init__( args, config, layer_number, cp_comm_type, - model_comm_pgs, + pg_collection, ) if Qwen3NextAttention is None: raise ImportError("Please install transformers>=4.35.0 to use Qwen3NextAttention.") @@ -223,5 +223,4 @@ def get_qwen3_next_spec(args, config, vp_stage): params={"args": args}, ) transformer_layer_spec.layer_specs[layer_id] = layer_specs - transformer_layer_spec.layer_specs[layer_id].submodules.mlp.submodules.shared_experts.params = {"gate": True} return transformer_layer_spec diff --git a/requirements.txt b/requirements.txt index 2c20195fc..dacd51132 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow +pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/scripts/models/deepseek-v32-5layer.sh b/scripts/models/deepseek-v32-5layer.sh new file mode 100644 index 000000000..2466640af --- /dev/null +++ b/scripts/models/deepseek-v32-5layer.sh @@ -0,0 +1 @@ +MODEL_ARGS_NUM_LAYERS=5 source "$(dirname -- "${BASH_SOURCE[0]}")/deepseek-v32.sh" diff --git a/scripts/models/deepseek-v32.sh b/scripts/models/deepseek-v32.sh new file mode 100644 index 000000000..a98f2f561 --- /dev/null +++ b/scripts/models/deepseek-v32.sh 
@@ -0,0 +1,69 @@ +NLAYERS="${MODEL_ARGS_NUM_LAYERS:-61}" +FIRST_K_DENSE_REPLACE=3 + +arr=() +for ((i=0; i convert_model.py << EOF +import torch +import os +from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config + +model_id = "openai/gpt-oss-20b" +output_dir = "/root/models/gpt-oss-20b-bf16" + +if os.path.exists(output_dir): + print(f"Model already exists at {output_dir}, skipping conversion.") +else: + print(f"Converting model from {model_id} to {output_dir}...") + + quantization_config = Mxfp4Config(dequantize=True) + model_kwargs = dict( + attn_implementation="eager", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + use_cache=False, + device_map="auto", + ) + + model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + + # Patch config + model.config.attn_implementation = "eager" + + model.save_pretrained(output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(output_dir) + print("Conversion done.") +EOF + +python3 convert_model.py + + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 +export CUDA_VISIBLE_DEVICES=4,5,6,7 + +CKPT_ARGS=( + --hf-checkpoint /root/models/gpt-oss-20b-bf16 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 1000 + --rollout-batch-size 4 + --n-samples-per-prompt 4 + --rollout-max-response-len 2048 + --rollout-temperature 0.8 + + --global-batch-size 16 +) + +GRPO_ARGS=( + --advantage-estimator grpo + # --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +SGLANG_ARGS=( + # Set equal to the number of GPUs per node for colocated mode + --rollout-num-gpus-per-engine 4 + --sglang-tensor-parallel-size 1 + --sglang-dtype bfloat16 + --sglang-decode-log-interval 1000 +) + + +WANDB_ARGS=( + --use-wandb + --wandb-project "miles-fsdp-gpt" + --wandb-group "20b-bf16" + --wandb-key ${WANDB_API_KEY} +) + +# launch the master node of ray in container +ray start --head --node-ip-address 127.0.0.1 --num-gpus 4 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "no_proxy": "localhost,127.0.0.1,0.0.0.0,${MASTER_ADDR}" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + --train-backend fsdp \ + --bf16 \ + --attn-implementation eager \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${WANDB_ARGS[@]} \ diff --git a/scripts/run-qwen3-235B-A22B-sft.sh b/scripts/run-qwen3-235B-A22B-sft.sh index 54598d27d..50b46c048 100644 --- a/scripts/run-qwen3-235B-A22B-sft.sh +++ b/scripts/run-qwen3-235B-A22B-sft.sh @@ -49,6 +49,7 @@ SFT_ARGS=( --rollout-function-path miles.rollout.sft_rollout.generate_rollout --prompt-data ${BASE_FOLDER}/openhermes2_5.parquet --input-key messages + # --apply-chat-template --rollout-shuffle --num-epoch 3 --rollout-batch-size 128 diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh index 321a9712d..998f06b7f 100755 --- a/scripts/run-qwen3-4B-amd.sh +++ b/scripts/run-qwen3-4B-amd.sh @@ -15,13 +15,13 @@ set -euxo pipefail ### AMD Support ### 
-MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment +MILES_DIR="${MILES_DIR:-/root}" # Default path if not set in environment export MILES_DIR -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment +MODEL_DIR="${MODEL_DIR:-/root}" # Default path if not set in environment export MODEL_DIR -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment +DATA_DIR="${DATA_DIR:-/root}" # Default path if not set in environment export DATA_DIR # For AMD GPU @@ -139,20 +139,22 @@ NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') +# Dynamically detect Megatron-LM installation path +MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" + --runtime-env-json="{ + \"env_vars\": { + \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } - }' \ + }" \ -- python3 train.py \ --actor-num-nodes 1 \ --actor-num-gpus-per-node 8 \ --colocate \ + --no-offload-train \ + --no-offload-rollout \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ ${ROLLOUT_ARGS[@]} \ diff --git a/scripts/run-qwen3-4B-base-sft.sh b/scripts/run-qwen3-4B-base-sft.sh index 64cd055bd..6086313e0 100644 --- a/scripts/run-qwen3-4B-base-sft.sh +++ b/scripts/run-qwen3-4B-base-sft.sh @@ -38,6 +38,7 @@ SFT_ARGS=( --rollout-function-path miles.rollout.sft_rollout.generate_rollout --prompt-data /root/openhermes2_5.parquet --input-key messages + # --apply-chat-template --rollout-shuffle --num-epoch 3 --rollout-batch-size 128 diff --git a/scripts/run-qwen3-4B-fsdp.sh b/scripts/run-qwen3-4B-fsdp.sh index 3c95442d5..9fa339178 100644 --- a/scripts/run-qwen3-4B-fsdp.sh +++ b/scripts/run-qwen3-4B-fsdp.sh @@ -75,12 +75,16 @@ OPTIMIZER_ARGS=( --adam-beta2 0.98 ) -WANDB_ARGS=( - --use-wandb - --wandb-project miles-dev-mcore-fsdp - --wandb-group qwen3-4B-fsdp-1130-ref - --wandb-key ${WANDB_API_KEY} -) +if [ -z "${WANDB_API_KEY}" ]; then + WANDB_ARGS=() +else + WANDB_ARGS=( + --use-wandb + --wandb-project miles-dev-mcore-fsdp + --wandb-group qwen3-4B-fsdp-1130-ref + --wandb-key "${WANDB_API_KEY}" + ) +fi SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 @@ -128,15 +132,15 @@ RUNTIME_ENV_JSON="{ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${TRAIN_BACKEND_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${MISC_ARGS[@]} + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${TRAIN_BACKEND_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${MISC_ARGS[@]}" diff --git a/scripts/run-qwen3-next-80B-A3B-8gpus.sh b/scripts/run-qwen3-next-80B-A3B-8gpus.sh new file mode 100644 index 000000000..7e36e1944 --- /dev/null +++ b/scripts/run-qwen3-next-80B-A3B-8gpus.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 
+ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# if base folder not set raise error +if [ -z "${BASE_FOLDER}" ]; then + echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints." + exit 1 +fi + +if [ -z "${MASTER_ADDR}" ]; then + echo "MASTER_ADDR is not set. Please set it to the master node address." + exit 1 +fi + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-next-80B-A3B.sh" + +CKPT_ARGS=( + --hf-checkpoint ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking + --ref-load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_torch_dist + --load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_miles/ + --save ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_miles/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data ${BASE_FOLDER}/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 300 + --rollout-batch-size 16 + --n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 64 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 2 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 6 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 2048 +) + +GRPO_ARGS=( + --advantage-estimator gspo + #--use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 4e-4 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( +# --use-wandb +# --wandb-project miles-dev +# --wandb-group qwen3-next-80B-A3B-test +# --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --rollout-num-gpus 2 + --sglang-mem-fraction-static 0.8 + --sglang-ep-size 1 + + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 128) + + # mtp +# --sglang-speculative-algorithm EAGLE +# --sglang-speculative-num-steps 2 +# --sglang-speculative-eagle-topk 1 +# --sglang-speculative-num-draft-tokens 3 +# --sglang-enable-draft-weights-cpu-backup +# +# --sglang-max-running-requests 512 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 +# --grad-reduce-in-bf16 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash + + --moe-token-dispatcher-type alltoall +# --moe-enable-deepep +# --debug-rollout-only +) + +# launch the master node of ray in container +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start 
--head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do + if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then + continue + fi + echo "Starting Ray worker on ${WORKER_IP}" + ssh root@"${WORKER_IP}" \ + "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" & +done +wait + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"no_proxy\": \"${no_proxy}\", + \"MASTER_ADDR\": \"${MASTER_ADDR}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 6 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/scripts/run-qwen3-next-80B-A3B-fsdp.sh b/scripts/run-qwen3-next-80B-A3B-fsdp.sh new file mode 100644 index 000000000..786db35c0 --- /dev/null +++ b/scripts/run-qwen3-next-80B-A3B-fsdp.sh @@ -0,0 +1,181 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# if base folder not set raise error +if [ -z "${BASE_FOLDER}" ]; then + echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints." + exit 1 +fi + +if [ -z "${MASTER_ADDR}" ]; then + echo "MASTER_ADDR is not set. Please set it to the master node address." 
+ exit 1 +fi + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +CKPT_ARGS=( + --hf-checkpoint ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking +# --ref-load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking + --load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_miles/ + --save ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_miles/ + --save-interval 20 +) + + +ROLLOUT_ARGS=( + --prompt-data ${BASE_FOLDER}/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 300 + --rollout-batch-size 4 + --n-samples-per-prompt 3 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 12 +# --balance-data +) + +EVAL_ARGS=( + --eval-interval 10 + --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 1 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --micro-batch-size 1 +# --use-dynamic-batch-size + --max-tokens-per-gpu 1 +) + +GRPO_ARGS=( + --advantage-estimator gspo + #--use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 4e-4 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( +# --use-wandb +# --wandb-project miles-dev +# --wandb-group qwen3-next-80B-A3B-test +# --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --rollout-num-gpus 2 + --sglang-mem-fraction-static 0.8 + --sglang-ep-size 1 + + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 128) + + # mtp +# --sglang-speculative-algorithm EAGLE +# --sglang-speculative-num-steps 2 +# --sglang-speculative-eagle-topk 1 +# --sglang-speculative-num-draft-tokens 3 +# --sglang-enable-draft-weights-cpu-backup +# +# --sglang-max-running-requests 512 +) + +TRAIN_BACKEND_ARGS=( + --train-backend fsdp +# --update-weight-buffer-size 536870912 + --gradient-checkpointing +# --fp16 +# --attn-implementation flash_attention_3 + --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 +# --accumulate-allreduce-grads-in-fp32 +# --attention-softmax-in-fp32 + # need to comment this when using model with MLA + +# --moe-enable-deepep +# --debug-train-only +) + +# launch the master node of ray in container +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do + if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then + continue + fi + echo "Starting Ray worker on ${WORKER_IP}" + ssh root@"${WORKER_IP}" \ + "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" & +done +wait + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": 
\"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"no_proxy\": \"${no_proxy}\", + \"MASTER_ADDR\": \"${MASTER_ADDR}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 6 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${TRAIN_BACKEND_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/scripts/run_deepseek_v32.py b/scripts/run_deepseek_v32.py new file mode 100644 index 000000000..1065ebbd8 --- /dev/null +++ b/scripts/run_deepseek_v32.py @@ -0,0 +1,317 @@ +""" +This file is in preview, and will be further refined and optimized. +""" + +import re +from dataclasses import dataclass +from typing import Literal +import typer + +import miles.utils.external_utils.command_utils as U + +app = typer.Typer() + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "debug_minimal" + run_id: str = U.create_run_id() + model_org: str = "deepseek-ai" + model_name: Literal["DeepSeek-V3.2", "DeepSeek-V3.2-5layer"] = "DeepSeek-V3.2" + megatron_model_type: Literal["deepseek-v32", "deepseek-v32-5layer"] = "deepseek-v32" + num_gpus_per_node: int = 4 + enable_eval: bool = True + extra_args: str = "" + task: Literal["dapo_aime", "gsm8k"] = "dapo_aime" + enable_deepep: bool = True + data_dir: str = "/root" + model_dir: str = "/root/models" + model_local_dir: str = "/root/models" + save_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + + +@app.command() +@U.dataclass_cli +def prepare_single(args: ScriptArgs): + """This script only needs to be executed on one node.""" + match args.task: + case "dapo_aime": + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.hf_download_dataset("zhuzilin/aime-2024", data_dir=args.data_dir) + case "gsm8k": + U.hf_download_dataset("zhuzilin/gsm8k", data_dir=args.data_dir) + + U.fp8_cast_bf16( + path_src=f"{args.model_dir}/{args.model_name}", + path_dst=f"{args.model_dir}/{args.model_name}-bf16/", + ) + + +@app.command() +@U.dataclass_cli +def prepare_spmd(args: ScriptArgs): + # TODO unify 5layer w/ 20layer, also maybe unify the whole script + extra_args = "--tensor-model-parallel-size 1 " "--expert-tensor-parallel-size 1 " + if args.num_nodes == 1 and args.model_name == "DeepSeek-V3.2-5layer": + extra_args += "--pipeline-model-parallel-size 1 " "--expert-model-parallel-size 1 " + else: + extra_args += ( + "--pipeline-model-parallel-size 8 " + "--expert-model-parallel-size 4 " + "--decoder-first-pipeline-num-layers 7 " + "--decoder-last-pipeline-num-layers 6 " + ) + + U.convert_checkpoint( + model_name=args.model_name, + hf_checkpoint=f"{args.model_dir}/{args.model_name}-bf16", + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + multinode=True if args.num_nodes > 1 else False, + extra_args=extra_args, + dir_dst=f"{args.model_dir}", + megatron_path=args.megatron_path, + ) + + +@app.command() +@U.dataclass_cli +def prepare_cp(args: ScriptArgs): + _prepare_cp(args) + + +def _prepare_cp(args: ScriptArgs): + U.rsync_simple( + path_src=f"{args.model_dir}/{args.model_name}_torch_dist", + path_dst=f"{args.model_local_dir}/{args.model_name}_torch_dist", + ) + U.rsync_simple( + path_src=f"{args.model_dir}/{args.model_name}", + path_dst=f"{args.model_local_dir}/{args.model_name}", + ) + + 
+@app.command() +@U.dataclass_cli +def train(args: ScriptArgs): + print(f"running on {args.num_nodes} nodes") + # ensure files are there if they were not synced before + # _prepare_cp(args) + + load_save_path = f"{args.save_dir}/{args.run_id}/checkpoints" + ckpt_args = ( + f"--hf-checkpoint {args.model_local_dir}/{args.model_name} " + f"--ref-load {args.model_local_dir}/{args.model_name}_torch_dist " + f"--load {load_save_path} " + f"--save {load_save_path} " + "--save-interval 20 " + "--save-retain-interval 20 " + ) + + rollout_args = ( + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3000 " + "--rollout-batch-size 1 " + "--n-samples-per-prompt 1 " + "--rollout-temperature 0.8 " + # ------------ + "--num-steps-per-rollout 1 " + "--balance-data " + ) + + if args.mode != "debug_minimal": + rollout_args += ( + "--over-sampling-batch-size 256 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + ) + + # sometimes disable eval to speed up debugging + eval_args = "" + if (args.mode != "debug_minimal") and args.enable_eval: + eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " + + match args.task: + case "dapo_aime": + rollout_args += ( + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + ) + eval_args += ( + f"--eval-prompt-data aime {args.data_dir}/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 8 " + "--eval-max-response-len 8192 " + ) + case "gsm8k": + rollout_args += ( + f"--prompt-data {args.data_dir}/gsm8k/train.parquet " + "--input-key messages " + # Deliberately make it very short for this easy task + "--rollout-max-response-len 256 " + ) + eval_args += ( + f"--eval-prompt-data gsm8k {args.data_dir}/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 256 " + ) + + if args.num_nodes <= 2: + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 4 " + "--expert-tensor-parallel-size 1 " + ) + elif args.num_nodes <= 4: + # TODO remove this temp cfg + perf_args = ( + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 4 " + "--expert-model-parallel-size 4 " + "--expert-tensor-parallel-size 1 " + ) + else: + # TODO choose a good config (currently changed arbitrarily to suit 64 GPUs) + perf_args = ( + "--tensor-model-parallel-size 8 " + "--sequence-parallel " + f"--pipeline-model-parallel-size {1 if args.model_name == 'DeepSeek-V3.2-5layer' else 4} " + "--context-parallel-size 2 " + "--expert-model-parallel-size 16 " + "--expert-tensor-parallel-size 1 " + ) + if re.search(r"(\d+)layer", args.model_name) is None: + perf_args += "--decoder-last-pipeline-num-layers 13 " + perf_args += ( + # ------------ + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + # ------------ + # "--use-dynamic-batch-size " + "--micro-batch-size 1 " + # TODO temp use tiny value + "--max-tokens-per-gpu 2048 " + # "--max-tokens-per-gpu 16384 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + # TODO run-deepseek-r1.sh enables use-kl-loss but w/ coef 0. can we just disable it like this?
+ # "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-miles-router " + "--use-rollout-routing-replay " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + # ------------ + # "--optimizer-cpu-offload " + # "--overlap-cpu-optimizer-d2h-h2d " + # "--use-precision-aware-optimizer " + ) + + sglang_decode_max_bs = 256 + sglang_world_size = 4 if args.num_nodes <= 4 else 64 + sglang_attn_dp_size = 1 if args.num_nodes <= 4 else 8 + sglang_attn_tp_size = sglang_world_size // sglang_attn_dp_size + sglang_args = ( + f"--rollout-num-gpus-per-engine {sglang_world_size} " + "--sglang-mem-fraction-static 0.7 " + f"--sglang-tp-size {sglang_world_size} " + f"--sglang-ep-size {sglang_world_size} " + # dp attention + "--sglang-enable-dp-attention " + f"--sglang-dp-size {sglang_attn_dp_size} " + "--sglang-moe-dense-tp-size 1 " + "--sglang-enable-dp-lm-head " + # make every dp rank has 128 concurrency + "--sglang-server-concurrency 1024 " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + "--sglang-disable-cuda-graph " + # For quick experiments + # """--sglang-json-model-override-args '{"num_hidden_layers": 5}' """ + ) + sglang_extra_env_vars = {} + + if args.enable_deepep: + sglang_args += ( + "--sglang-moe-a2a-backend deepep " + "--sglang-moe-runner-backend deep_gemm " + "--sglang-deepep-mode low_latency " + ) + sglang_extra_env_vars["SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = f"{sglang_decode_max_bs}" + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + # "--attention-backend flash " + f"--update-weight-buffer-size {4 * 1024 ** 3} " + # TODO maybe enable it + # use deepep for megatron + # "--moe-enable-deepep " + # "--moe-token-dispatcher-type flex " + # ------------ + f"--actor-num-nodes {args.num_nodes} " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + "--colocate " + "--use-fault-tolerance " + f"--dump-details /root/shared_data/{args.run_id}/dump_details " + "--disable-weights-backuper " + "--model-name deepseekv32 " # for mbridge load + "--train-memory-margin-bytes 1073741824 " + # "--check-weight-update-equal " + "--qkv-format bshd " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + U.execute_train( + train_args=train_args, + config=args, + # TODO may get it from `config` + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + extra_env_vars={**sglang_extra_env_vars}, + megatron_path=args.megatron_path, + ) + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/scripts/run_qwen3_4b.py b/scripts/run_qwen3_4b.py index d1aa63301..393b8a12e 100644 --- a/scripts/run_qwen3_4b.py +++ b/scripts/run_qwen3_4b.py @@ -40,7 +40,7 @@ def 
__post_init__(self): def prepare(args: ScriptArgs): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{args.model_name} --local-dir /root/models/{args.model_name}") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir /root/models/{args.model_name}") U.hf_download_dataset("zhuzilin/dapo-math-17k") U.hf_download_dataset("zhuzilin/aime-2024") @@ -49,9 +49,7 @@ def prepare(args: ScriptArgs): U.hf_download_dataset("zyzshishui0627/IFBench") if args.rollout_fp8: - U.exec_command( - f"huggingface-cli download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8" - ) + U.exec_command(f"hf download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8") if (args.train_backend == "megatron") and not args.enable_megatron_bridge: U.convert_checkpoint( @@ -64,23 +62,21 @@ def prepare(args: ScriptArgs): def execute(args: ScriptArgs): load_save_path = f"/root/shared_data/{args.run_id}/checkpoints" + + ref_load_path = f"/root/models/{args.model_name}" + if args.train_backend == "megatron" and not args.enable_megatron_bridge: + ref_load_path = f"/root/models/{args.model_name}_torch_dist" + ckpt_args = ( f"--hf-checkpoint /root/models/{args.model_name}{'-FP8' if args.rollout_fp8 else ''} " f"--load {load_save_path} " + f"--ref-load {ref_load_path} " f"--save {load_save_path} " f"--save-interval {2 if args.mode == 'debug_minimal' else 20} " - f"--save-retain-interval {2 if args.mode == 'debug_minimal' else 20} " ) + if args.train_backend == "megatron": - ref_load_path = ( - f"/root/models/{args.model_name}/" - if args.enable_megatron_bridge - else f"/root/models/{args.model_name}_torch_dist" - ) - ckpt_args += ( - # FSDP does not support this - f"--ref-load {ref_load_path} " - ) + ckpt_args += f"--save-retain-interval {2 if args.mode == 'debug_minimal' else 20} " rollout_args = ( "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 9507e2e85..20379f76a 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,11 +19,14 @@ def main(): _execute_print_only(args) return - fd_locks = _try_acquire(args) + if args.count == 0 and not args.devices: + print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) + else: + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep new file mode 100644 index 000000000..615f2b076 --- /dev/null +++ b/tests/e2e/.gitkeep @@ -0,0 +1 @@ +# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py new file mode 100644 index 000000000..4cb30e91f --- /dev/null +++ b/tests/fast/conftest.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from tests.fast.fixtures.generation_fixtures import generation_env +from tests.fast.fixtures.rollout_fixtures import rollout_env + +_ 
= rollout_env, generation_env + + +@pytest.fixture(autouse=True) +def enable_experimental_rollout_refactor(): + os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" + yield + os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/fast/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py new file mode 100644 index 000000000..816371ee3 --- /dev/null +++ b/tests/fast/fixtures/generation_fixtures.py @@ -0,0 +1,274 @@ +""" +Fixtures to test custom-generate-function +""" + +from argparse import Namespace +from contextlib import contextmanager +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState +from miles.router.router import MilesRouter +from miles.utils.async_utils import run +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +from miles.utils.types import Sample + +MODEL_NAME = "Qwen/Qwen3-0.6B" +RESPONSE_TEXT = "\\boxed{8}" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} + +VARIANT_TO_GENERATE_FN_PATH = { + "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", + "single_turn": "miles.rollout.generate_hub.single_turn.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", + "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", + "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", +} + + +def extra_argv_for_variant( + variant: str, + *, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", +) -> list[str]: + argv = [ + "--custom-generate-function-path", + custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], + ] + + if variant in ( + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ): + argv += [ + "--generate-max-turns", + str(generate_max_turns), + "--generate-tool-specs-path", + generate_tool_specs_path, + "--generate-execute-tool-function-path", + generate_execute_tool_function_path, + ] + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + argv.append("--generate-multi-samples") + + return argv + + +def listify(x): + return x if isinstance(x, list) else [x] + + +def make_sample( + *, + prompt: str | list[dict] = "What is 1+7?", + 
tokens: list[int] | None = None, + response: str = "", + response_length: int = 0, + status: Sample.Status = Sample.Status.PENDING, + multimodal_inputs: dict | None = None, +) -> Sample: + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample | list[Sample] + requests: list[dict] + + +def run_generate( + env: GenerateEnv, + sample: Sample, + sampling_params: dict[str, Any] | None = None, + *, + variant: str = "single_turn", +) -> GenerateResult: + env.mock_server.request_log.clear() + result_sample = run( + _call_generate( + env.args, + sample, + sampling_params or DEFAULT_SAMPLING_PARAMS, + variant=variant, + ) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +async def _call_generate( + args: Namespace, + sample: Sample, + sampling_params: dict[str, Any], + *, + variant: str = "single_turn", +) -> Sample: + generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) + state = GenerateState(args) + input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + output = await generate_fn(input) + return output.samples + + +def make_args( + *, + variant: str, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, + extra_argv: list[str] | None = None, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + rollout_max_context_len: int | None = None, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + if rollout_max_context_len is not None: + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) + + argv.extend( + extra_argv_for_variant( + variant, + custom_generate_function_path=custom_generate_function_path, + generate_max_turns=generate_max_turns, + generate_tool_specs_path=generate_tool_specs_path, + generate_tool_call_parser=generate_tool_call_parser, + generate_execute_tool_function_path=generate_execute_tool_function_path, + ) + ) + + if extra_argv: + argv.extend(extra_argv) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +@contextmanager +def with_miles_router(backend_url: str, model_name: str): + router_args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + 
hf_checkpoint=model_name, + ) + router = MilesRouter(router_args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend_url}) + + try: + yield port + finally: + server.stop() + + +@pytest.fixture +def generation_env(request, variant): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + with with_miles_router(mock_server.url, model_name) as router_port: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + variant=variant, + router_port=router_port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + yield GenerateEnv(args=args, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py new file mode 100644 index 000000000..44d8a50d7 --- /dev/null +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -0,0 +1,127 @@ +""" +Fixtures to test rollout-function +""" + +import json +from argparse import Namespace +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer +from miles.router.router import MilesRouter +from miles.utils.arguments import parse_args +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class RolloutEnvConfig: + extra_argv: list[str] | None = None + data_rows: list[dict] | None = None + latency: float = 0.0 + + +@dataclass(frozen=True) +class RolloutEnv: + args: Namespace + data_source: DataSource + mock_server: MockSGLangServer + + +def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + data_path, + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--eval-prompt-data", + "toy", + data_path, + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + 
str(router_port), + "--rollout-max-response-len", + "16", + ] + (extra_argv or []) + with patch("sys.argv", argv): + args = parse_args() + args.miles_router_middleware_paths = [] + init_http_client(args) + return args + + +@contextmanager +def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + try: + server.start() + yield server + finally: + server.stop() + + +def _write_jsonl(path: str, rows: list[dict]) -> None: + Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") + + +DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] + + +@pytest.fixture +def rollout_env(tmp_path, request) -> RolloutEnv: + config = request.param + assert isinstance(config, RolloutEnvConfig) + + data_rows = config.data_rows or DEFAULT_DATA_ROWS + + data_path = str(tmp_path / "data.jsonl") + _write_jsonl(data_path, data_rows) + + router_port = find_available_port(20000) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) + + SingletonMeta.clear_all_instances() + + with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: + with _with_miles_router(args) as router_server: + r = requests.post( + f"{router_server.url}/add_worker", + params={"url": mock_server.url}, + timeout=5.0, + ) + r.raise_for_status() + + data_source = RolloutDataSourceWithBuffer(args) + yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py new file mode 100644 index 000000000..5d974aaad --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -0,0 +1,572 @@ +from copy import deepcopy +from dataclasses import dataclass, replace +from itertools import groupby + +import numpy as np +import pybase64 +import pytest +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThreeTurnStub, TwoTurnStub +from miles.utils.types import Sample + +_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub + + +def is_agentic_variant(variant: str) -> bool: + return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") + + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +@pytest.fixture( + params=[ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ] +) +def variant(request): + return request.param + + +@dataclass(frozen=True) +class SampleParsedChunk: + tokens_decoded_str: str + 
loss_mask_value: int + rollout_log_probs: list[float] + + +@dataclass +class ExpectedSampleInfo: + chunks: list[SampleParsedChunk] + partial_sample: Sample + + +def token_len(text: str) -> int: + return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) + + +def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: + n = token_len(text) + log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n + return SampleParsedChunk(text, loss_mask, log_probs) + + +def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: + prompt_len = len(sample.tokens) - sample.response_length + response_tokens = sample.tokens[prompt_len:] + loss_mask = sample.loss_mask or [] + log_probs = sample.rollout_log_probs or [] + + chunks = [] + idx = 0 + for mask_val, group in groupby(loss_mask): + group_len = len(list(group)) + sli = slice(idx, idx + group_len) + chunks.append( + SampleParsedChunk( + tokens_decoded_str=tokenizer.decode(response_tokens[sli]), + loss_mask_value=mask_val, + rollout_log_probs=log_probs[sli], + ) + ) + idx += group_len + return chunks + + +def expected_partial_sample( + *, + prompt: list[dict], + response: str, + response_length: int, + status: Sample.Status = Sample.Status.COMPLETED, +) -> Sample: + return Sample( + prompt=prompt, + response=response, + response_length=response_length, + status=status, + tokens=[], + loss_mask=[], + rollout_log_probs=[], + weight_versions=[], + spec_info=Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + + +def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]): + actual = listify(actual) + assert len(actual) == len(expected) + + for actual_item, expected_item in zip(actual, expected, strict=True): + actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) + assert actual_chunks == expected_item.chunks + + actual_partial = replace( + deepcopy(actual_item), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_item.partial_sample + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): + return run_generate(env, sample, sampling_params, variant=variant) + + +def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: + return { + "input_ids": input_ids, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + "return_routed_experts": False, + } + + +def expected_openai_request(messages: list[dict]) -> dict: + return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS} + + +SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] +SINGLE_TURN_RESPONSE = "The answer is 2." 
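+# Helper conventions (as exercised by the tests below, assuming the mock SGLang server's deterministic scheme): expected_chunk expects the i-th generated token of an unmasked chunk to carry log prob -i/128, while masked tool-observation chunks (loss_mask == 0) carry all-zero log probs; SINGLE_TURN_RESPONSE tokenizes to 6 tokens for this tokenizer, which is why the single-turn tests expect rollout_log_probs over range(6).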
+_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( + SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS +) +SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] +SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicMultiTurn: + def test_single_turn_no_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] + else: + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 + ), + ), + ], + ) + + def test_two_turns_with_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + S = TwoTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestExitConditions: + def test_partial_rollout_not_supported(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not check partial_rollout flag") + generation_env.args.partial_rollout = True + + with pytest.raises(AssertionError, match="Partial rollout is not supported"): + _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + def test_abort_preserves_content(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not handle abort finish_reason") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="abort" + ) + + 
result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), + ), + ], + ) + + def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ], + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) + def test_max_turns_reached(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRespectMaxContextLen: + @pytest.mark.parametrize( + "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True + ) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + assert result.requests == [] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED + ), + ) + ] + else: + expected = [] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) + + 
token_len(TwoTurnStub.FIRST_RESPONSE) + + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) + } + } + ], + indirect=True, + ) + def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=partial_response, + response_length=token_len(partial_response), + status=Sample.Status.TRUNCATED, + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, + 10, + ), + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, + 64, + ), + ], + indirect=["generation_env"], + ) + def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert len(result.requests) >= 2 + assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] + + +class TestThreeTurn: + """Need to test 3-turn case besides 2-turn, because e.g. 
merge_samples may behave differently.""" + + def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): + generation_env.mock_server.process_fn = ThreeTurnStub.process_fn + + S = ThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRoutedExpertsMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "use_rollout_routing_replay": True, + } + } + ], + indirect=True, + ) + def test_two_turns_routed_experts(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + + S = TwoTurnStub + num_layers, moe_router_topk = 2, 4 + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + + def make_routed_experts(prompt_token_ids, response_text): + total_tokens = len(prompt_token_ids) + token_len(response_text) + routed_experts_len = total_tokens - 1 + return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( + routed_experts_len, num_layers, moe_router_topk + ) + + first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) + second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) + + def process_fn(prompt: str) -> ProcessResult: + if prompt == S.FIRST_PROMPT: + text, routed_experts = S.FIRST_RESPONSE, first_routed_experts + elif prompt == S.SECOND_PROMPT: + text, routed_experts = S.SECOND_RESPONSE, second_routed_experts + else: + raise ValueError(f"Unexpected prompt: {prompt}") + return ProcessResult( + text=text, + finish_reason="stop", + meta_info=ProcessResultMetaInfo( + routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") + ), + 
) + + generation_env.mock_server.process_fn = process_fn + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) + + sample = result.sample[-1] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == second_routed_experts.shape + np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) + assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py new file mode 100644 index 000000000..a58e6fb3c --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_single_turn.py @@ -0,0 +1,424 @@ +import numpy as np +import pybase64 +import pytest +import torch +from PIL import Image +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoProcessor + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.types import Sample + +_ = generation_env + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +PROMPT = "What is 1+7?" +PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) +RESPONSE_TOKENS = [59, 79075, 90, 23, 92] +RESPONSE_TEXT = "\\boxed{8}" +RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] +SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] + + +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) +def variant(request): + return request.param + + +def expected_request( + variant: str, + *, + input_ids: list[int] | None = None, + sampling_params: dict | None = None, + return_routed_experts: bool = False, + image_data: list[str] | None = None, +) -> dict: + result = { + "input_ids": input_ids or PROMPT_TOKENS, + "sampling_params": sampling_params or SAMPLING_PARAMS, + "return_logprob": True, + } + if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: + result["return_routed_experts"] = return_routed_experts + if image_data is not None: + result["image_data"] = image_data + return result + + +class _Unset: + pass + + +_UNSET = _Unset() + + +def expected_sample( + variant: str, + *, + prompt: str = PROMPT, + response: str = RESPONSE_TEXT, + response_length: int = 5, + tokens: list[int] | None | _Unset = _UNSET, + rollout_log_probs: list[float] | None | _Unset = _UNSET, + status: Sample.Status = Sample.Status.COMPLETED, + cached_tokens: int = 0, + prompt_tokens: int = 7, + weight_versions: list[str] | None = None, + rollout_routed_experts: np.ndarray | None = None, + spec_info: Sample.SpecInfo | None = None, + multimodal_inputs: dict | None = None, + multimodal_train_inputs: dict | None = None, + loss_mask: list[int] | None | _Unset = _UNSET, +) -> Sample: + actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) + if isinstance(loss_mask, _Unset): + loss_mask = ( + [1] * actual_response_length + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") + else None + ) + + return Sample( + group_index=None, + 
index=None, + prompt=prompt, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=multimodal_train_inputs, + response=response, + response_length=response_length, + label=None, + reward=None, + loss_mask=loss_mask, + weight_versions=weight_versions or [], + rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, + rollout_routed_experts=rollout_routed_experts, + remove_sample=False, + status=status, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=spec_info or Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), + ) + + +def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return make_sample( + prompt=PROMPT, + tokens=tokens, + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicGeneration: + def test_basic_generation(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestResumedSingleTurn: + def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + partial_text = "\\boxed" + partial_tokens = [59, 79075] + partial_log_probs = [-0.0, -0.0078125] + + remaining_text = "{8}" + remaining_tokens = [90, 23, 92] + remaining_log_probs = [-0.0, -0.0078125, -0.015625] + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = _make_sample() + result1 = _run_generate(variant, generation_env, sample) + assert result1.requests == [expected_request(variant)] + assert result1.sample == expected_sample( + variant, + response=partial_text, + response_length=2, + tokens=PROMPT_TOKENS + partial_tokens, + rollout_log_probs=partial_log_probs, + status=Sample.Status.ABORTED, + ) + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = _run_generate(variant, generation_env, result1.sample) + tokens_after_turn1 = PROMPT_TOKENS + partial_tokens + assert result2.requests == [ + expected_request( + variant, + input_ids=tokens_after_turn1, + sampling_params={"max_new_tokens": 14, "temperature": 0.7}, + ) + ] + assert result2.sample == expected_sample( + variant, + response=partial_text + remaining_text, + response_length=2 + 3, + tokens=tokens_after_turn1 + remaining_tokens, + rollout_log_probs=partial_log_probs + remaining_log_probs, + prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), + status=Sample.Status.COMPLETED, + ) + + +class TestFinishReason: + @pytest.mark.parametrize( + "generation_env,expected_status", + [ + ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), + ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), + ({"process_fn_kwargs": 
{"finish_reason": "abort"}}, Sample.Status.ABORTED), + ], + indirect=["generation_env"], + ) + def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, status=expected_status)] + + +class TestRoutedExperts: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"use_rollout_routing_replay": True}, + "process_fn_kwargs": {"routed_experts": "placeholder"}, + } + ], + indirect=True, + ) + def test_routed_experts_enabled_and_parsed(self, variant, generation_env): + num_layers, moe_router_topk = 2, 4 + num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) + routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( + num_tokens - 1, num_layers, moe_router_topk + ) + + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=RESPONSE_TEXT, + finish_reason="stop", + meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), + ) + + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant, return_routed_experts=True)] + sample = result.sample[0] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) + + +class TestMetaInfo: + @pytest.mark.parametrize( + "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + ) + def test_meta_info_fields_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, + "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, + } + ], + indirect=True, + ) + def test_spec_info_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample( + variant, + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ), + ) + ] + + +class TestInputStatusValidation: + @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) + def test_allowed_statuses(self, variant, generation_env, status): + result = _run_generate(variant, generation_env, _make_sample(status=status)) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) + def test_rejected_statuses(self, variant, generation_env, status): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + with pytest.raises(AssertionError): + _run_generate(variant, 
generation_env, _make_sample(status=status)) + + +class TestPayloadStructure: + def test_sampling_params_passed_through(self, variant, generation_env): + result = _run_generate( + variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} + ) + assert result.requests == [ + expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + ] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestBoundaryConditions: + def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) + sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + + result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7}) + assert result.requests == [] + assert result.sample == expected_sample( + variant, + response="x" * 10, + response_length=10, + tokens=existing_tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), + ], + indirect=["generation_env"], + ) + def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert len(result.requests) == 1 + assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], + indirect=True, + ) + def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert 
result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + +class TestEmptyResponse: + @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) + ] + + +VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" + + +class TestMultimodal: + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + test_image = Image.new("RGB", (64, 64), color="red") + multimodal_inputs = {"images": [test_image]} + processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) + expected_mti = { + k: v + for k, v in processor(text=PROMPT, **multimodal_inputs).items() + if k not in ["input_ids", "attention_mask"] + } + + result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) + + assert result.requests == [ + expected_request( + variant, + input_ids=PROMPT_TOKENS, + image_data=[encode_image_for_rollout_engine(test_image)], + ) + ] + actual_mti = result.sample.multimodal_train_inputs + assert actual_mti is not None + assert set(actual_mti.keys()) == set(expected_mti.keys()) + assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) + assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) + assert result.sample == expected_sample( + variant, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=actual_mti, + ) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py new file mode 100644 index 000000000..0f2305e75 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -0,0 +1,99 @@ +import pytest + +from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "XiaomiMiMo/MiMo-7B-RL", +] + +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +] + +# Models where tokenize->decode produces extra whitespace vs direct string diff +TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ + "THUDM/glm-4-9b-chat", +] + +SAMPLE_TOOL_RESPONSES = [ + { + "role": "tool", + "tool_call_id": "call00000", + "content": '{"year": 2026}', + "name": "get_year", + }, + { + "role": 
"tool", + "tool_call_id": "call00001", + "content": '{"temperature": 25}', + "name": "get_temperature", + }, +] + + +class TestTokenizeToolResponses: + @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) + def test_snapshot(self, model_name): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) + decoded = tokenizer.decode(token_ids) + + assert decoded == ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": 25}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize("num_tools", [1, 2]) + @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) + def test_tokenize_tool_responses(self, model_name, num_tools): + if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: + pytest.skip(f"{model_name} only supports single tool call") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] + assert len(tool_responses) == num_tools + + actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) + actual_str = tokenizer.decode(actual_token_ids) + + dummy_assistant = _build_dummy_assistant(tool_responses) + base_messages = [_DUMMY_USER, dummy_assistant] + expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) + + if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: + # Some models produce whitespace differences between tokenize->decode and direct string diff + actual_str = actual_str.replace(" ", "") + expected_str = expected_str.replace(" ", "") + + assert actual_str == expected_str, f"{model_name=}" + + @staticmethod + def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: + text_with = tokenizer.apply_chat_template( + base_messages + extra_messages, tokenize=False, add_generation_prompt=True + ) + text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) + return text_with[len(text_without) :] diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py new file mode 100644 index 000000000..c53fbbb56 --- /dev/null +++ b/tests/fast/rollout/generate_utils/test_sample_utils.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.generate_utils.sample_utils import _merge_sample_pair +from miles.utils.types import Sample + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.decode = lambda tokens: f"" + return tokenizer + + +def make_sample( + prompt="test_prompt", + tokens=None, + response="", + response_length=0, + loss_mask=None, + rollout_log_probs=None, + status=Sample.Status.COMPLETED, + label="test_label", + reward=1.0, + index=0, + group_index=0, +): + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + loss_mask=loss_mask, + rollout_log_probs=rollout_log_probs, + status=status, + label=label, + reward=reward, + index=index, + group_index=group_index, + ) + + +class TestMergeSamples: + def test_basic_merge(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3, 10, 11, 12], + 
response="response1", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.1, -0.2, -0.3], + ) + b = make_sample( + tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], + response="response2", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.4, -0.5, -0.6], + status=Sample.Status.TRUNCATED, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.tokens == b.tokens + assert merged.response_length == 3 + 2 + 3 + assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] + assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] + assert merged.prompt == a.prompt + assert merged.status == b.status + assert merged.label == a.label + assert merged.index == a.index + assert merged.group_index == a.group_index + assert "response1" in merged.response + assert "response2" in merged.response + assert "" in merged.response + + def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.loss_mask == [1, 0, 1] + assert merged.rollout_log_probs == [0.0, 0.0, 0.0] + + def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 99, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_field_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + index=0, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=[1], + index=1, + ) + + with pytest.raises(AssertionError, match="index mismatch"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_obs_len_invalid_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="obs_len must be > 0"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_sample_validate_fails_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10, 11], + response_length=2, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 11, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="loss_mask length"): + _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py new file mode 100644 index 000000000..ca47edeeb --- /dev/null +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import parse_args + + +def _build_mock_args(extra_argv: list[str] | None = None): + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "4", + "--rollout-num-gpus-per-engine", + "2", + 
"--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + "/dev/null", + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + "30000", + ] + (extra_argv or []) + with patch("sys.argv", argv): + return parse_args() + + +@pytest.fixture +def mock_args(): + return _build_mock_args() diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py new file mode 100644 index 000000000..5b791829d --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -0,0 +1,69 @@ +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import ( + MODULAR_ROLLOUT_BASE_ARGV, + expected_sample, + load_and_call_train, +) + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function + +_VARIANTS = [ + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="old_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="new_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_train(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert len(group) == env.args.n_samples_per_prompt + assert group[0] == expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_eval(rollout_env): + env = rollout_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path + ) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == expected_sample(group_index=None) diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py new file mode 100644 index 000000000..69a235911 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py @@ -0,0 +1,37 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, 
load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env,expected_seeds", + [ + pytest.param( + integration_env_config( + [ + "--sglang-enable-deterministic-inference", + "--rollout-seed", + "42", + "--n-samples-per-prompt", + "3", + "--rollout-batch-size", + "1", + ] + ), + {42, 43, 44}, + id="enabled", + ), + pytest.param( + integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + {None}, + id="disabled", + ), + ], + indirect=["rollout_env"], +) +def test_sampling_seeds(rollout_env, expected_seeds): + env = rollout_env + load_and_call_train(env.args, env.data_source) + + seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} + assert seeds == expected_seeds diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py new file mode 100644 index 000000000..0ca5743ac --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py @@ -0,0 +1,46 @@ +from contextlib import nullcontext + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + MIXED_DATA_ROWS, + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_env,use_filter,expect_all_correct", + [ + pytest.param( + integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + False, + False, + id="no_filter", + ), + pytest.param( + integration_env_config( + ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], + data_rows=MIXED_DATA_ROWS, + ), + True, + True, + id="with_filter", + ), + ], + indirect=["rollout_env"], +) +def test_filter_effect(rollout_env, use_filter, expect_all_correct): + env = rollout_env + ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() + + with ctx: + out = load_and_call_train(env.args, env.data_source) + + rewards = {group[0].reward for group in out.samples} + if expect_all_correct: + assert rewards == {1}, "Filter should keep only correct samples" + else: + assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py new file mode 100644 index 000000000..afd870c30 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py @@ -0,0 +1,22 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + id="group_rm_enabled", + ), + ], + indirect=True, +) +def test_group_rm_rewards_set(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + rewards = [sample.reward for group in out.samples for sample in group] + assert all(r in (0, 1) for r in rewards) diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py new file mode 100644 index 000000000..2b12d3d88 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py @@ -0,0 
+1,65 @@ +import pytest +from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.misc import function_registry +from miles.utils.types import Sample + + +async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: + sample = input.sample + s1 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=None, + status=Sample.Status.COMPLETED, + ) + s2 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=0.5, + status=Sample.Status.COMPLETED, + ) + return GenerateFnOutput(samples=[s1, s2]) + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + + [ + "--custom-generate-function-path", + "test:multi_sample_generate", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + ], + data_rows=DEFAULT_DATA_ROWS, + ), + id="multi_sample_output", + ), + ], + indirect=True, +) +def test_multi_sample_output_preserves_existing_reward(rollout_env): + env = rollout_env + with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert isinstance(group[0], list) + samples = group[0] + assert len(samples) == 2 + assert samples[0].reward == 1 + assert samples[1].reward == 0.5 diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py new file mode 100644 index 000000000..c41d71399 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py @@ -0,0 +1,114 @@ +from typing import Any + +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout + +from miles.utils.test_utils.mock_tools import TwoTurnStub +from miles.utils.types import Sample + + +TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] + +_VARIANT_NAMES = [ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", +] + +_BASE_EXTRA_ARGV = [ + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "2", + "--n-samples-per-eval-prompt", + "2", + "--custom-rm-path", + "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", +] + + +def _config_for_variant(variant: str) -> RolloutEnvConfig: + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, + data_rows=TWO_TURN_DATA_ROWS, + ) + + +@pytest.mark.parametrize( + "variant,rollout_env", + [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], + indirect=["rollout_env"], +) +@pytest.mark.parametrize("test_type", ["train", "eval"]) +def test_rollout(rollout_env, variant, test_type): + env = 
rollout_env + env.mock_server.process_fn = TwoTurnStub.process_fn + + out = load_and_call_rollout(env.args, env.data_source, mode=test_type) + + if test_type == "train": + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + _verify_samples(variant, group) + else: + assert "toy" in out.data + samples = out.data["toy"]["samples"] + _verify_samples(variant, samples) + + +def _verify_samples(variant: str, samples: list[Any]): + is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") + + if is_multi_samples: + if len(samples) > 0 and isinstance(samples[0], list): + # Train mode: list[list[Sample]], grouped by prompt + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for group_sample in samples: + assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" + _verify_group_samples(group_sample) + else: + # Eval mode: list[Sample], flattened + # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples + assert ( + len(samples) == 4 + ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" + # Group samples by prompt (every 2 samples form a group) + group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] + for group_samples in group_samples_list: + _verify_group_samples(group_samples) + else: + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for sample in samples: + assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" + _verify_sample(sample) + + +def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): + assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" + for i, sample in enumerate(group_samples): + _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) + + +def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): + assert sample.status == Sample.Status.COMPLETED + assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" + if expect_answer: + assert "2008" in sample.response, "Response should contain final answer '2008'" + + +async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: + if isinstance(samples, list): + # For multi_samples variants, use the last sample's reward + if getattr(args, "generate_multi_samples", False): + return [_check_reward(samples[-1])] * len(samples) + else: + return [_check_reward(sample) for sample in samples] + else: + return _check_reward(samples) + + +def _check_reward(sample: Sample) -> float: + return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py new file mode 100644 index 000000000..0812962cc --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py @@ -0,0 +1,48 @@ +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "wrong"}, + {"input": "What is 1+9?", 
"label": "wrong"}, + {"input": "What is 1+6?", "label": "wrong"}, +] + +_BASE_ARGV = [ + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", +] + + +def _over_sampling_config(rollout_batch_size: int): + return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + + +@pytest.mark.parametrize( + "rollout_env,expected_rounds", + [ + pytest.param(_over_sampling_config(1), 1, id="one_round"), + pytest.param(_over_sampling_config(2), 2, id="two_rounds"), + ], + indirect=["rollout_env"], +) +def test_over_sampling_rounds(rollout_env, expected_rounds): + env = rollout_env + + with function_registry.temporary("test:filter_by_reward", filter_by_reward): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + assert all(group[0].reward == 1 for group in out.samples) + + requests_count = len(env.mock_server.request_log) + expected_requests = expected_rounds * env.args.over_sampling_batch_size + assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py new file mode 100644 index 000000000..36e78c16c --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py @@ -0,0 +1,67 @@ +from unittest.mock import Mock + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +# Data with only 2 reward=1 samples out of 4. +# This ensures all 4 samples must be generated to collect 2 valid ones. 
+_FILTER_TEST_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, # reward=1 + {"input": "What is 1+8?", "label": "wrong"}, # reward=0 + {"input": "What is 1+9?", "label": "wrong"}, # reward=0 + {"input": "What is 1+6?", "label": "7"}, # reward=1 +] + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config( + [ + "--rollout-batch-size", + "2", + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", + "--rollout-sample-filter-path", + "test:sample_filter", + "--rollout-all-samples-process-path", + "test:all_samples_process", + ], + data_rows=_FILTER_TEST_DATA_ROWS, + ), + id="sample_filter_vs_all_samples", + ), + ], + indirect=True, +) +def test_sample_filter_and_all_samples_process(rollout_env): + env = rollout_env + sample_filter_mock = Mock() + all_samples_process_mock = Mock() + + with ( + function_registry.temporary("test:filter_by_reward", filter_by_reward), + function_registry.temporary("test:sample_filter", sample_filter_mock), + function_registry.temporary("test:all_samples_process", all_samples_process_mock), + ): + load_and_call_train(env.args, env.data_source) + + sample_filter_mock.assert_called_once() + _, filtered_data = sample_filter_mock.call_args[0] + rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] + assert all(r == 1 for r in rewards) + + all_samples_process_mock.assert_called_once() + _, all_samples, data_source = all_samples_process_mock.call_args[0] + assert data_source is not None + + assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py new file mode 100644 index 000000000..889a9ff8a --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py @@ -0,0 +1,33 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + +_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] +_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] + + +@pytest.mark.parametrize( + "rollout_env,expected_range", + [ + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (1, 1), + id="limit_1", + ), + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (2, 999), + id="no_limit", + ), + ], + indirect=["rollout_env"], +) +def test_max_concurrent(rollout_env, expected_range): + env = rollout_env + load_and_call_train(env.args, env.data_source) + min_expected, max_expected = expected_range + assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py new file mode 100644 index 000000000..ad413cf94 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -0,0 +1,89 @@ +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnOutput, + RolloutFnTrainInput, +) +from miles.rollout.filter_hub.base_types import 
DynamicFilterOutput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +MODULAR_ROLLOUT_BASE_ARGV = [ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", +] + +MIXED_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "9"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "7"}, +] + + +def integration_env_config( + extra_argv: list[str], + data_rows: list[dict] | None = None, + latency: float = 0.0, + variant: str = "single_turn", +): + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, + data_rows=data_rows, + latency=latency, + ) + + +def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: + function_path = args.rollout_function_path if mode == "train" else args.eval_function_path + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), + function_path, + ) + if mode == "train": + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + else: + return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + +def load_and_call_train(args, data_source): + return load_and_call_rollout(args, data_source, mode="train") + + +def filter_by_reward(args, samples, **kwargs): + reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward + if reward == 1: + return DynamicFilterOutput(keep=True) + return DynamicFilterOutput(keep=False, reason="reward_zero") diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py new file mode 100644 index 000000000..ddfecd067 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/test_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.inference_rollout.compatibility import ( + LegacyGenerateFnAdapter, + LegacyRolloutFnAdapter, + call_rollout_function, + load_generate_function, + load_rollout_function, +) +from miles.utils.async_utils import run +from miles.utils.misc import function_registry + + +@pytest.fixture +def constructor_input(): + return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") + + +@pytest.fixture +def make_generate_fn_input(): + def _make(evaluation: bool = False): 
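+ # Build a minimal GenerateFnInput; state and args are MagicMock stand-ins rather than a real rollout state.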
+ state = MagicMock() + state.args = MagicMock() + + return GenerateFnInput( + state=state, + sample={"text": "test prompt"}, + sampling_params={"temperature": 0.7}, + evaluation=evaluation, + ) + + return _make + + +class TestSupportedRolloutFormats: + """ + Documentation test to show various supported rollout function formats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return {"metric": {"accuracy": 0.9}} + return [[{"text": "sample"}]] + + with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_rollout") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) + return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) + + with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_typed") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_sync_class(self, constructor_input, evaluation): + class SyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + def __call__(self, input): + if input.evaluation: + return RolloutFnEvalOutput(data={"test": {"score": 1}}) + return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) + + with function_registry.temporary("test:sync_class", SyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:sync_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_async_class(self, constructor_input, evaluation): + class AsyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + async def __call__(self, input): + await asyncio.sleep(0.001) + if input.evaluation: + return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) + return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) + + with function_registry.temporary("test:async_class", AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:async_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = 
call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + +class TestSupportedGenerateFormats: + """ + Documentation test similar to TestSupportedRolloutFormats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): + return "my_sample" + + with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen_eval") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params): + return "my_sample" + + with function_registry.temporary("test:legacy_gen", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): + async def generate(input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_async", generate): + fn = load_generate_function("test:new_async") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): + class MyGenerateFn: + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_class", MyGenerateFn): + fn = load_generate_function("test:new_class") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py new file mode 100644 index 000000000..bd4c606a6 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_deepscaler.py @@ -0,0 +1,26 @@ +import pytest + +from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward + + +class TestGetDeepscalerRuleBasedReward: + @pytest.mark.parametrize( + "response,label,expected", + [ + (r"Let me analyze...The answer is \boxed{42}", "42", 1), + (r"Thinking...The answer is \boxed{wrong}", "42", 0), + (r"###Response\boxed{42}", "42", 1), + (r"###Response\boxed{wrong}", "42", 0), + (r"The answer is \boxed{42}", "42", 0), + (r"The answer is 42", "42", 0), + (r"\boxed{42}", "", 0), + (r"\boxed{42}", r"\boxed{42}", 1), + (r"\boxed{123}", 123, 1), + (r"\boxed{3.14}", 3.14, 1), + 
(r"\boxed{1/2}", "0.5", 1), + (r"\boxed{\frac{1}{2}}", "0.5", 1), + (r"First thoughtSecond thought\boxed{42}", "42", 1), + ], + ) + def test_get_deepscaler_rule_based_reward(self, response, label, expected): + assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py new file mode 100644 index 000000000..c9ecf9614 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_f1.py @@ -0,0 +1,44 @@ +import pytest + +from miles.rollout.rm_hub.f1 import f1_score, normalize_answer + + +class TestNormalizeAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("The quick brown fox", "quick brown fox"), + ("A cat and a dog", "cat and dog"), + ("Hello, world!", "hello world"), + (" multiple spaces ", "multiple spaces"), + ("An apple", "apple"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_normalize_answer(self, input_str, expected): + assert normalize_answer(input_str) == expected + + +class TestF1Score: + @pytest.mark.parametrize( + "prediction,ground_truth,expected_f1,expected_prec,expected_recall", + [ + ("hello world", "hello world", 1.0, 1.0, 1.0), + ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), + ("abc", "xyz", 0, 0, 0), + (None, "anything", 0, 0, 0), + ("yes", "no", 0, 0, 0), + ("no", "yes", 0, 0, 0), + ("yes", "yes", 1.0, 1.0, 1.0), + ("noanswer", "yes", 0, 0, 0), + ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), + ("hello, world!", "hello world", 1.0, 1.0, 1.0), + ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), + ], + ) + def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): + f1, prec, recall = f1_score(prediction, ground_truth) + assert f1 == expected_f1 + assert prec == expected_prec + assert recall == expected_recall diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py new file mode 100644 index 000000000..45cefd201 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_gpqa.py @@ -0,0 +1,86 @@ +import pytest + +from miles.rollout.rm_hub.gpqa import ( + _extract_letter_from_response, + _normalize_text, + _strip_chain_of_thought, + compute_gpqa_reward, +) + + +class TestStripChainOfThought: + @pytest.mark.parametrize( + "text,expected", + [ + ("Let me think...The answer is A", "The answer is A"), + ("The answer is A", "The answer is A"), + ("", ""), + (None, ""), + ], + ) + def test_strip_chain_of_thought(self, text, expected): + assert _strip_chain_of_thought(text) == expected + + +class TestNormalizeText: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("Test-123", "test 123"), + ("A, B, C", "a b c"), + ("", ""), + ], + ) + def test_normalize_text(self, input_str, expected): + assert _normalize_text(input_str) == expected + + +class TestExtractLetterFromResponse: + @pytest.mark.parametrize( + "response,expected", + [ + ("The answer is A", "A"), + ("answer: B", "B"), + ("I think C is correct", "C"), + ("final answer: D", "D"), + ("Option A is the best choice", "A"), + ("The answer is B", "B"), + ("After analysis, my choice is C", "C"), + ("A B C D", "D"), + ("No valid letter here", None), + ("", None), + (None, None), + ("The answer is Z", None), + ], + ) + def test_extract_letter(self, response, expected): + assert _extract_letter_from_response(response, "ABCD") == expected + + +class TestComputeGpqaReward: + @pytest.mark.parametrize( + "response,label,metadata,expected", + [ + ("Answer: 
A", "A", None, 1.0), + ("Answer: A", "B", None, 0.0), + (None, "A", None, 0.0), + ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), + ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), + ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), + ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), + ( + "I believe the answer is Paris", + "", + {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, + 1.0, + ), + ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), + ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), + ("Let me think step by step...The answer is A", "A", None, 1.0), + ], + ) + def test_compute_gpqa_reward(self, response, label, metadata, expected): + assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py new file mode 100644 index 000000000..56a7f6d1f --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py @@ -0,0 +1,108 @@ +import pytest + +from miles.rollout.rm_hub.math_dapo_utils import ( + compute_score, + is_correct_minerva, + is_correct_strict_box, + last_boxed_only_string, + normalize_final_answer, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2}", r"\boxed{x^2}"), + (r"No boxed", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x + 1}", "x + 1"), + ], + ) + def test_remove_boxed_valid(self, input_str, expected): + assert remove_boxed(input_str) == expected + + def test_remove_boxed_invalid(self): + with pytest.raises(AssertionError): + remove_boxed("not boxed") + + +class TestNormalizeFinalAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("42", "42"), + (" 42 ", "42"), + (r"\text{hello}", "hello"), + (r"\textbf{bold}", "bold"), + (r"x = 42", "42"), + (r"100 square", "100"), + (r"$50$ dollars", "50"), + (r"\boxed{42}", "42"), + (r"\frac12", r"\frac{1}{2}"), + (r"\sqrt3", r"\sqrt{3}"), + ("1,000", "1000"), + ("<|im_end|>", ""), + ], + ) + def test_normalize_final_answer(self, input_str, expected): + assert normalize_final_answer(input_str) == expected + + +class TestIsCorrectMinerva: + @pytest.mark.parametrize( + "solution,gt,gt_need_extract,expected_correct", + [ + ("Answer: 42", "42", False, True), + ("Answer: 100", "42", False, False), + ("Answer: wrong", "42", False, False), + ("Answer: 42", r"\boxed{42}", True, True), + ], + ) + def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): + correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) + assert correct == expected_correct + + +class TestIsCorrectStrictBox: + @pytest.mark.parametrize( + "pred,gt,expected_score,expected_pred", + [ + (r"blah blah \boxed{42}", "42", 1, "42"), + (r"\boxed{wrong}", "42", -1, "wrong"), + ("no box here", "42", -1, None), + ], + ) + def test_is_correct_strict_box(self, pred, gt, expected_score, 
expected_pred): + score, extracted = is_correct_strict_box(pred, gt) + assert score == expected_score + assert extracted == expected_pred + + +class TestComputeScore: + @pytest.mark.parametrize( + "solution,gt,strict_box,expected_score,expected_acc", + [ + ("Answer: 42", "42", False, 1.0, True), + ("Answer: wrong", "42", False, -1.0, False), + (r"\boxed{42}", "42", True, 1.0, True), + ("x" * 500 + " Answer: 42", "42", False, 1.0, True), + ], + ) + def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): + result = compute_score(solution, gt, strict_box_verify=strict_box) + assert result["score"] == expected_score + assert result["acc"] == expected_acc diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py new file mode 100644 index 000000000..2423ed4ac --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_utils.py @@ -0,0 +1,129 @@ +import pytest + +from miles.rollout.rm_hub.math_utils import ( + _normalize, + extract_answer, + grade_answer_mathd, + grade_answer_sympy, + grade_answer_verl, + last_boxed_only_string, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), + (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), + (r"No boxed here", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), + (r"\fbox{fbox content}", r"\fbox{fbox content}"), + ("", None), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x^2 + 1}", "x^2 + 1"), + (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), + ("not boxed", None), + ], + ) + def test_remove_boxed(self, input_str, expected): + assert remove_boxed(input_str) == expected + + +class TestExtractAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", "42"), + (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), + (r"Multiple \boxed{1} then \boxed{final}", "final"), + (r"No boxed here", None), + ("", None), + ], + ) + def test_extract_answer(self, input_str, expected): + assert extract_answer(input_str) == expected + + +class TestNormalize: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("1,000", "1000"), + (r"\text{hello}", "hello"), + (" 42 ", "42"), + (r"100%", "100"), + (r"\$50", "50"), + ("HELLO", "hello"), + ("1,234,567", "1234567"), + (None, None), + ], + ) + def test_normalize(self, input_str, expected): + assert _normalize(input_str) == expected + + +class TestGradeAnswerMathd: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + (" 42 ", "42", True), + (r"\frac{1}{2}", r"\frac{1}{2}", True), + ("wrong", "42", False), + ("", "42", False), + ], + ) + def test_grade_answer_mathd(self, given, ground_truth, expected): + assert grade_answer_mathd(given, ground_truth) == expected + + +class TestGradeAnswerSympy: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + ("x^2", "x^2", True), + ("1/2", "0.5", True), + (r"\frac{1}{2}", "0.5", True), + ("wrong", "42", False), + ("", "42", False), + ("(1,2)", "(1,2)", True), + ("(1,2,3)", "(1,2)", False), + ("42", None, False), + ], + ) + def test_grade_answer_sympy(self, given, ground_truth, expected): + assert 
grade_answer_sympy(given, ground_truth) == expected + + +class TestGradeAnswerVerl: + @pytest.mark.parametrize( + "solution,ground_truth,expected", + [ + (r"\boxed{42}", "42", True), + (r"The answer is \boxed{42}", "42", True), + (r"\boxed{1/2}", r"\frac{1}{2}", True), + (r"\boxed{wrong}", "42", False), + ("no boxed", "42", False), + (r"\boxed{42}", r"\boxed{42}", True), + ("", "42", False), + (r"\boxed{42}", "", False), + (r"\boxed{42}", None, False), + ], + ) + def test_grade_answer_verl(self, solution, ground_truth, expected): + assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py new file mode 100644 index 000000000..a3dadbdaf --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_rm_hub.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.async_utils import run +from miles.utils.types import Sample + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.custom_rm_path = None + args.rm_type = None + args.rm_url = None + return args + + +class TestAsyncRm: + @pytest.mark.parametrize( + "rm_type,response,label,expected", + [ + ("math", r"\boxed{42}", "42", 1), + ("math", r"\boxed{wrong}", "42", 0), + ("f1", "hello world", "hello world", 1.0), + ("dapo", "Answer: 42", "42", {"score": 1.0}), + ("deepscaler", r"\boxed{42}", "42", 1), + ("gpqa", "Answer: A", "A", 1.0), + ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), + ], + ) + def test_rm_types(self, mock_args, rm_type, response, label, expected): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response=response, label=label) + reward = run(async_rm(mock_args, sample)) + if isinstance(expected, dict): + for k, v in expected.items(): + assert reward[k] == v + else: + assert reward == expected + + def test_f1_rm_partial(self, mock_args): + mock_args.rm_type = "f1" + sample = Sample(prompt="", response="hello", label="hello world") + reward = run(async_rm(mock_args, sample)) + assert 0 < reward < 1 + + def test_random_rm(self, mock_args): + mock_args.rm_type = "random" + sample = Sample(prompt="", response="anything", label="anything") + reward = run(async_rm(mock_args, sample)) + assert reward in [0, 1] + + def test_rm_type_from_metadata(self, mock_args): + mock_args.rm_type = None + sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) + reward = run(async_rm(mock_args, sample)) + assert reward == 1 + + @pytest.mark.parametrize( + "rm_type,match", + [ + ("unknown_type", "not implemented"), + ("", "not specified"), + ], + ) + def test_invalid_rm_type_raises(self, mock_args, rm_type, match): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response="test", label="test") + with pytest.raises(NotImplementedError, match=match): + run(async_rm(mock_args, sample)) + + +class TestBatchedAsyncRm: + @pytest.mark.parametrize( + "rm_type,samples_data,expected", + [ + ( + "math", + [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], + [1, 1, 0], + ), + ( + "f1", + [("hello world", "hello world"), ("different", "something else")], + [1.0, 0], + ), + ], + ) + def test_batched_rm(self, mock_args, rm_type, samples_data, expected): + mock_args.rm_type = rm_type + samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards == expected + + def 
test_inplace_set_reward_field(self, mock_args): + mock_args.rm_type = "math" + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42"), + Sample(prompt="", response=r"\boxed{100}", label="100"), + ] + result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + assert result is None + assert samples[0].reward == 1 + assert samples[1].reward == 1 + + def test_inplace_raises_on_existing_reward(self, mock_args): + mock_args.rm_type = "math" + samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] + with pytest.raises(AssertionError, match="Overriding"): + run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + + def test_empty_samples(self, mock_args): + mock_args.rm_type = "math" + rewards = run(batched_async_rm(mock_args, [])) + assert rewards == [] + + def test_mixed_rm_types_via_metadata(self, mock_args): + mock_args.rm_type = None + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), + Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), + ] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards[0] == 1 + assert rewards[1] == 1.0 diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py new file mode 100644 index 000000000..7c645fe30 --- /dev/null +++ b/tests/fast/router/test_router.py @@ -0,0 +1,204 @@ +import asyncio +from argparse import Namespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def make_router_args(router_port: int, **overrides) -> Namespace: + defaults = dict( + sglang_router_ip="127.0.0.1", + sglang_router_port=router_port, + rollout_health_check_interval=1.0, + miles_router_health_check_failure_threshold=3, + miles_router_max_connections=100, + miles_router_timeout=None, + miles_router_middleware_paths=[], + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: + port = find_available_port(start_port) + return MockSGLangServer( + model_name="Qwen/Qwen3-0.6B", + process_fn=default_process_fn, + host="127.0.0.1", + port=port, + latency=0.0, + ) + + +class RouterEnv: + def __init__(self, router: MilesRouter, server: UvicornThreadServer): + self.router = router + self.server = server + + @property + def url(self) -> str: + return self.server.url + + +@pytest.fixture +def router_env(): + args = make_router_args(find_available_port(20000)) + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + server.start() + yield RouterEnv(router, server) + server.stop() + + +@pytest.fixture +def mock_worker(): + server = create_mock_worker() + server.start() + yield server + server.stop() + + +@pytest.fixture +def mock_worker_factory(): + servers = [] + + def _create(): + start_port = 30000 + len(servers) * 100 + server = create_mock_worker(start_port) + server.start() + servers.append(server) + return server + + yield _create + for s in servers: + s.stop() + + +@pytest.fixture +def router_factory(): + def _create(**overrides) -> MilesRouter: + 
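+ # Each call builds an independent MilesRouter configured with a freshly found free port; the app is not served here.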
args = make_router_args(find_available_port(20000), **overrides) + return MilesRouter(args, verbose=False) + + return _create + + +class TestWorkerManagement: + def test_add_worker_via_query_param(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30001" + r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + assert router_env.router.worker_request_counts[worker_url] == 0 + + def test_add_worker_via_body(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30002" + r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_duplicate(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30003" + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + + assert len(router_env.router.worker_request_counts) == 1 + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_missing_url(self, router_env: RouterEnv): + r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) + assert r.status_code == 400 + assert "error" in r.json() + + def test_list_workers(self, router_env: RouterEnv): + worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] + for url in worker_urls: + requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) + r.raise_for_status() + assert set(r.json()["urls"]) == set(worker_urls) + + +class TestLoadBalancing: + def test_use_url_selects_min_load(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_use_url_excludes_dead_workers(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} + router.dead_workers = {"http://w2:8000"} + + selected = router._use_url() + assert selected == "http://w3:8000" + assert router.worker_request_counts["http://w3:8000"] == 4 + + def test_use_url_raises_when_all_dead(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 0} + router.dead_workers = {"http://w1:8000"} + + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# TODO: extract main body inside `_health_check_loop`, then can test that function +class TestHealthCheck: + def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) + assert url == mock_worker.url + assert healthy is True + + def test_check_worker_health_failure(self, router_factory): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) + assert url == "http://127.0.0.1:59999" + assert healthy is False + + +class TestProxyIntegration: + def 
test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) + r.raise_for_status() + + assert "text" in r.json() + assert len(mock_worker.request_log) == 1 + assert mock_worker.request_log[0] == payload + + def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): + worker1, worker2 = mock_worker_factory(), mock_worker_factory() + requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) + requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + for _ in range(4): + requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() + + all_requests = worker1.request_log + worker2.request_log + assert len(all_requests) == 4 + assert all(req == payload for req in all_requests) + + def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/health", timeout=5.0) + r.raise_for_status() + assert r.json()["status"] == "ok" diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py new file mode 100644 index 000000000..5c6edafe2 --- /dev/null +++ b/tests/fast/router/test_sessions.py @@ -0,0 +1,195 @@ +from types import SimpleNamespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.router.sessions import SessionManager, SessionRecord +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +class TestSessionManager: + def test_create_session(self): + manager = SessionManager() + session_id = manager.create_session() + assert session_id is not None + assert len(session_id) == 32 + assert session_id in manager.sessions + assert manager.sessions[session_id] == [] + + def test_get_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.get_session(session_id) + assert records == [] + + def test_get_session_not_exists(self): + manager = SessionManager() + records = manager.get_session("nonexistent") + assert records is None + + def test_delete_session_exists(self): + manager = SessionManager() + session_id = manager.create_session() + records = manager.delete_session(session_id) + assert records == [] + assert session_id not in manager.sessions + + def test_delete_session_not_exists(self): + manager = SessionManager() + with pytest.raises(AssertionError): + manager.delete_session("nonexistent") + + def test_add_record(self): + manager = SessionManager() + session_id = manager.create_session() + record = SessionRecord( + timestamp=1234567890.0, + method="POST", + path="generate", + request={"prompt": "hello"}, + response={"text": "world"}, + status_code=200, + ) + manager.add_record(session_id, record) + assert len(manager.sessions[session_id]) == 1 + assert manager.sessions[session_id][0] == record + + def test_add_record_nonexistent_session(self): + manager = SessionManager() + record = SessionRecord( + 
timestamp=1234567890.0, + method="POST", + path="generate", + request={}, + response={}, + status_code=200, + ) + with pytest.raises(AssertionError): + manager.add_record("nonexistent", record) + + +@pytest.fixture(scope="class") +def router_url(): + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text=f"echo: {prompt}", finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint="Qwen/Qwen3-0.6B", + ) + router = MilesRouter(args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}) + + try: + yield url + finally: + server.stop() + + +class TestSessionRoutes: + def test_create_session(self, router_url): + response = requests.post(f"{router_url}/sessions") + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert len(data["session_id"]) == 32 + + def test_get_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert data["records"] == [] + + def test_get_session_not_found(self, router_url): + response = requests.get(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_get_with_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + assert get_resp.status_code == 200 + data = get_resp.json() + assert data["session_id"] == session_id + assert len(data["records"]) == 1 + + def test_delete_session(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 + assert delete_resp.text == "" + + assert requests.delete(f"{router_url}/sessions/{session_id}").status_code == 404 + + def test_delete_session_not_found(self, router_url): + response = requests.delete(f"{router_url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +class TestSessionProxy: + def test_proxy_session_not_found(self, router_url): + response = requests.post(f"{router_url}/sessions/nonexistent/generate", json={}) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + def test_proxy_records_request_response(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + resp = requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + ) + assert resp.status_code == 200 + assert "text" in resp.json() + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = 
get_resp.json()["records"] + assert len(records) == 1 + assert records[0]["method"] == "POST" + assert records[0]["path"] == "generate" + assert records[0]["request"]["input_ids"] == [1, 2, 3] + assert "text" in records[0]["response"] + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 + + def test_proxy_accumulates_records(self, router_url): + session_id = requests.post(f"{router_url}/sessions").json()["session_id"] + + for _ in range(3): + requests.post( + f"{router_url}/sessions/{session_id}/generate", + json={"input_ids": [1], "sampling_params": {}, "return_logprob": True}, + ) + + get_resp = requests.get(f"{router_url}/sessions/{session_id}") + records = get_resp.json()["records"] + assert len(records) == 3 + + delete_resp = requests.delete(f"{router_url}/sessions/{session_id}") + assert delete_resp.status_code == 204 diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py new file mode 100644 index 000000000..9bd1a620d --- /dev/null +++ b/tests/fast/utils/test_arguments.py @@ -0,0 +1,58 @@ +import argparse +import sys +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import get_miles_extra_args_provider +from miles.utils.misc import function_registry + +PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] +REQUIRED_ARGS = ["--rollout-batch-size", "64"] + + +def make_class_with_add_arguments(): + class MyFn: + @classmethod + def add_arguments(cls, parser): + parser.add_argument("--my-custom-arg", type=int, default=42) + + return MyFn + + +def make_function_with_add_arguments(): + def my_fn(): + pass + + my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) + return my_fn + + +def make_function_without_add_arguments(): + def my_fn(): + pass + + return my_fn + + +@pytest.mark.parametrize("path_arg", PATH_ARGS) +class TestAddArgumentsSupport: + + @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) + def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): + fn = fn_factory() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) + args, _ = parser.parse_known_args() + assert args.my_custom_arg == 100 + + def test_skips_function_without_add_arguments(self, path_arg): + fn = make_function_without_add_arguments() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) diff --git a/tests/fast/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py new file mode 100644 index 000000000..f54304b96 --- /dev/null +++ b/tests/fast/utils/test_mask_utils.py @@ -0,0 +1,99 @@ +from transformers import AutoTokenizer + +from miles.utils.mask_utils import MultiTurnLossMaskGenerator + + +def test_loss_mask_qwen3_simple(model_name: str = "Qwen/Qwen3-8B"): + tokenizer = AutoTokenizer.from_pretrained(model_name) + mask_generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen3") + messages = [ + {"role": "system", "content": "SYSTEM MESSAGE FOR TESTING ONLY"}, + {"role": "user", 
"content": "USER CONTENT FOR TESTING ONLY"}, + {"role": "assistant", "content": "ASSISTANT RESPONSE FOR TESTING ONLY"}, + ] + all_token_ids, all_loss_masks = mask_generator.gen_multi_turn_loss_mask_qwen3(messages) + assert len(all_token_ids) == len(all_loss_masks), f"{len(all_token_ids)} != {len(all_loss_masks)}" + selected_texts = mask_generator.get_text_from_loss_mask(all_token_ids, all_loss_masks) + assert len(selected_texts) == 1, f"Expected 1 text, got {len(selected_texts)}" + + print(f"==== Single Turn Test {model_name} ====") + print("text = ", [tokenizer.decode(all_token_ids)]) + print("token_ids = ", all_token_ids) + print("loss_mask = ", all_loss_masks) + print("selected_texts = ", selected_texts) + + +def test_loss_mask_qwen3_tools(model_name: str = "Qwen/Qwen3-8B"): + tokenizer = AutoTokenizer.from_pretrained(model_name) + mask_generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen3") + messages = [ + {"role": "system", "content": "SYSTEM MESSAGE FOR TESTING ONLY"}, + {"role": "user", "content": "USER CONTENT FOR TESTING ONLY"}, + { + "role": "assistant", + "content": "I WILL CALL terminal", + "tool_calls": [ + {"function": {"name": "terminal", "arguments": {"command": "ls"}}, "id": "call_0", "type": "function"}, + {"function": {"name": "terminal", "arguments": {"command": "ls"}}, "id": "call_0", "type": "function"}, + ], + }, + {"role": "tool", "name": "terminal", "content": "LICENSE README.md README_zh.md"}, + {"role": "tool", "name": "terminal", "content": "LICENSE README.md README_zh.md"}, + {"role": "assistant", "content": "ASSISTANT RESPONSE FOR TESTING ONLY"}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Perform operations from the terminal.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute as `bash -c `", + }, + "description": { + "type": "string", + "description": "Brief description of the command for the user.", + }, + }, + "required": ["command"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the content of a file given its path.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to be read.", + } + }, + "required": ["file_path"], + }, + }, + }, + ] + + all_token_ids, all_loss_masks = mask_generator.gen_multi_turn_loss_mask_qwen3(messages, tools) + assert len(all_token_ids) == len(all_loss_masks), f"{len(all_token_ids)} != {len(all_loss_masks)}" + selected_texts = mask_generator.get_text_from_loss_mask(all_token_ids, all_loss_masks) + assert len(selected_texts) == 2, f"Expected 2 texts, got {len(selected_texts)}" + + print(f"==== Multi-turn with Tools Test {model_name} ====") + print("text = ", [tokenizer.decode(all_token_ids)]) + print("token_ids = ", all_token_ids) + print("loss_mask = ", all_loss_masks) + print("selected_texts = ", selected_texts) + + +if __name__ == "__main__": + test_loss_mask_qwen3_simple("Qwen/Qwen3-Coder-30B-A3B-Instruct") + test_loss_mask_qwen3_tools("Qwen/Qwen3-Coder-30B-A3B-Instruct") diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py new file mode 100644 index 000000000..810c2b67c --- /dev/null +++ b/tests/fast/utils/test_misc.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from miles.utils.misc import FunctionRegistry, function_registry, load_function + + +def _fn_a(): + return "a" + + +def _fn_b(): + return 
"b" + + +class TestFunctionRegistry: + def test_register_and_get(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + + def test_register_duplicate_raises(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + with pytest.raises(AssertionError): + with registry.temporary("my_fn", _fn_b): + pass + + def test_unregister(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + assert registry.get("my_fn") is None + + def test_temporary_cleanup_on_exception(self): + registry = FunctionRegistry() + with pytest.raises(RuntimeError): + with registry.temporary("temp_fn", _fn_a): + raise RuntimeError("test") + assert registry.get("temp_fn") is None + + +class TestLoadFunction: + def test_load_from_module(self): + import os.path + + assert load_function("os.path.join") is os.path.join + + def test_load_none_returns_none(self): + assert load_function(None) is None + + def test_load_from_registry(self): + with function_registry.temporary("test:my_fn", _fn_a): + assert load_function("test:my_fn") is _fn_a + + def test_registry_takes_precedence(self): + with function_registry.temporary("os.path.join", _fn_b): + assert load_function("os.path.join") is _fn_b + assert load_function("os.path.join") is os.path.join diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py new file mode 100644 index 000000000..6633678da --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py @@ -0,0 +1,409 @@ +import asyncio +import concurrent.futures +import time + +import pytest +import requests + +from miles.utils.test_utils.mock_sglang_server import ( + Counter, + ProcessResult, + ProcessResultMetaInfo, + default_process_fn, + with_mock_server, +) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub + + +def expected_logprobs(tokenizer, text: str) -> list[dict]: + output_ids = tokenizer.encode(text, add_special_tokens=False) + return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] + + +@pytest.fixture(scope="module") +def mock_server(): + with with_mock_server() as server: + yield server + + +class TestProcessResultMetaInfo: + def test_to_dict_empty(self): + assert ProcessResultMetaInfo().to_dict() == {} + + def test_to_dict_single_field(self): + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} + + def test_to_dict_partial_fields(self): + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + + def test_to_dict_all_fields(self): + assert ProcessResultMetaInfo( + weight_version="v1", + routed_experts="abc", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + } + + +class TestDefaultProcessFn: + def test_math_question(self): + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + + def 
test_unknown_question(self): + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +class TestCounter: + def test_tracks_max(self): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 + + counter.reset() + assert counter.max_value == 0 + + def test_concurrent_tasks(self): + counter = Counter() + + async def task(): + with counter.track(): + await asyncio.sleep(0.1) + + async def run_all(): + await asyncio.gather(task(), task(), task()) + + asyncio.run(run_all()) + assert counter.max_value == 3 + + +class TestMockServerBasic: + def test_start_stop(self, mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + def test_request_log_and_reset_stats(self, mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) + def test_latency(self, latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time + + def test_max_concurrent_with_latency(self): + with with_mock_server(latency=0.1) as server: + + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) + + assert server.max_concurrent == 3 + + def test_health_endpoint(self, mock_server): + response = requests.get(f"{mock_server.url}/health", timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + def test_abort_request_endpoint(self, mock_server): + response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +class TestGenerateEndpoint: + def test_basic(self, mock_server): + prompt = "What is 1+7?" 
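+ # The literal token ids asserted below depend on the mock server's default tokenizer;
+ # the hard-coded values assume a Qwen-style tokenizer, matching the other fixtures in these tests.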
+ input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] + + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + assert response.json() == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + def test_with_meta_info(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + + assert response.json() == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_finish_reason_length(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text="truncated output", finish_reason="length") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + finish_reason = data["meta_info"]["finish_reason"] + assert finish_reason["type"] == "length" + assert finish_reason["length"] == data["meta_info"]["completion_tokens"] + + +class TestChatCompletionsEndpoint: + def test_basic(self, mock_server): + response = requests.post( + f"{mock_server.url}/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "What is 1+5?"}], + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data == { + "id": data["id"], + "object": "chat.completion", + "created": data["created"], + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, + "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "finish_reason": "stop", + } + ], + } + + def test_with_tool_calls(self): + tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=tool_call_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year is it?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + 
"index": 0, + "message": { + "role": "assistant", + "content": "Let me check for you.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "finish_reason": "tool_calls", + } + + def test_with_tools_but_no_tool_call(self): + response_text = "The weather is sunny today." + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=response_text, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": {"role": "assistant", "content": response_text, "tool_calls": None}, + "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "finish_reason": "stop", + } + + def test_with_multiple_tool_calls(self): + multi_tool_response = ( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' + ) + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=multi_tool_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year and temperature?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "I will get year and temperature.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, + }, + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "finish_reason": "tool_calls", + } + + +class TestMultiTurnToolCallProcessFn: + @pytest.mark.parametrize( + "prompt,expected_response", + [ + pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), + pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), + ], + ) + def test_generate_endpoint(self, prompt, expected_response): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) + response = requests.post( + f"{server.url}/generate", + json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["text"] == expected_response + assert data["meta_info"]["finish_reason"] == {"type": "stop"} + + @pytest.mark.parametrize( + "messages,expected_content,expected_tool_calls,expected_finish_reason", + [ + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, + TwoTurnStub.FIRST_RESPONSE_CONTENT, + TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, + "tool_calls", + id="first_turn", + ), + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, + TwoTurnStub.SECOND_RESPONSE, + None, + "stop", + id="second_turn", + ), + ], + ) + def 
test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_content + assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls + assert data["choices"][0]["finish_reason"] == expected_finish_reason diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py new file mode 100644 index 000000000..3f2116ec0 --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_tools.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call + + +class TestExecuteToolCall: + def test_execute_get_year(self): + result = asyncio.run(execute_tool_call("get_year", {})) + assert result == '{"year": 2026}' + + def test_execute_get_temperature(self): + result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) + assert result == '{"temperature": -60}' + + +class TestApplyChatTemplateWithTools: + EXPECTED_PROMPT_WITHOUT_TOOLS = ( + "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" + ) + + EXPECTED_PROMPT_WITH_TOOLS = ( + "<|im_start|>system\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What's the weather in Paris?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize( + "tools,expected", + [ + pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), + pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), + ], + ) + def test_apply_chat_template(self, tools, expected): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + + assert prompt == expected + + +class TestSGLangFunctionCallParser: + """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" + + @pytest.mark.parametrize( + "model_output,expected", + [ + pytest.param( + 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', + ( + "Let me check for you.", + 
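# Illustrative sketch (not part of this diff): how the pieces exercised above are expected
# to compose for one tool-call round trip -- parse the assistant output, run the mocked
# tools, and build the follow-up messages. Only APIs shown in these tests are used; the
# exact message layout for the next turn is an assumption for illustration.
import asyncio
import json

from pydantic import TypeAdapter
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.function_call_parser import FunctionCallParser

from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, execute_tool_call


def tool_call_round_trip(model_output: str) -> list[dict]:
    tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS)
    parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
    content, calls = parser.parse_non_stream(model_output)
    messages = [{"role": "assistant", "content": content}]
    for call in calls:
        # execute_tool_call is async and returns a JSON string (see TestExecuteToolCall).
        result = asyncio.run(execute_tool_call(call.name, json.loads(call.parameters)))
        messages.append({"role": "tool", "content": result})
    return messages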
[ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], + ), + id="single_tool_call", + ), + pytest.param( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', + ( + "I will get year and temperature.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ], + ), + id="multi_tool_calls", + ), + pytest.param( + "The weather is sunny today.", + ("The weather is sunny today.", []), + id="no_tool_call", + ), + pytest.param( + TwoTurnStub.FIRST_RESPONSE, + ( + "Let me get the year and temperature first.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), + ], + ), + id="multi_turn_first_response", + ), + ], + ) + def test_parse_non_stream(self, model_output, expected): + tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) + parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") + assert parser.parse_non_stream(model_output) == expected diff --git a/tests/test_chunked_gae.py b/tests/test_chunked_gae.py new file mode 100644 index 000000000..6640df8c2 --- /dev/null +++ b/tests/test_chunked_gae.py @@ -0,0 +1,63 @@ +import time +import pytest +import torch + +from miles.utils.ppo_utils import chunked_gae, vanilla_gae + + +@pytest.mark.parametrize( + "B,T", + [ + (16, 4096), + (32, 8192), + (256, 128 * 1024), + ], +) +@pytest.mark.parametrize("chunk_size", [64, 128, 256]) +def test_gae_parallel_matches_serial(B, T, chunk_size): + """ + Test that chunked_gae (parallel-scan) matches vanilla_gae (batch-serial) + under various shapes, chunk sizes and dtypes. 
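# Illustrative sketch (not part of this diff): the batch-serial recursion that vanilla_gae
# is assumed to implement, written out for readers of this test. The real miles signature
# and bootstrapping convention may differ; this version assumes V(T) = 0.
import torch


def reference_gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lam: float):
    """A_t = delta_t + gamma * lam * A_{t+1}, with delta_t = r_t + gamma * V_{t+1} - V_t."""
    B, T = rewards.shape
    advantages = torch.zeros_like(rewards)
    next_adv = torch.zeros(B, device=rewards.device, dtype=rewards.dtype)
    next_value = torch.zeros(B, device=rewards.device, dtype=rewards.dtype)  # V(T) = 0
    for t in reversed(range(T)):
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        next_adv = delta + gamma * lam * next_adv
        advantages[:, t] = next_adv
        next_value = values[:, t]
    returns = advantages + values
    return advantages, returns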
+ """ + device = "cuda" if torch.cuda.is_available() else "cpu" + torch.manual_seed(0) + + rewards = torch.randn(B, T, device=device, dtype=torch.float32) + values = torch.randn(B, T, device=device, dtype=torch.float32) + + gamma, lam = 0.99, 0.95 + + # ---------- Serial ---------- + if device == "cuda": + torch.cuda.synchronize() + t0 = time.time() + adv_s, ret_s = vanilla_gae(rewards, values, gamma, lam) + if device == "cuda": + torch.cuda.synchronize() + t1 = time.time() + serial_time = t1 - t0 + + # ---------- Parallel-scan ---------- + if device == "cuda": + torch.cuda.synchronize() + t0 = time.time() + adv_p, ret_p = chunked_gae(rewards, values, gamma, lam, chunk_size=chunk_size) + if device == "cuda": + torch.cuda.synchronize() + t1 = time.time() + parallel_time = t1 - t0 + + # ---------- Accuracy ---------- + adv_err = (adv_s - adv_p).abs().max().item() + ret_err = (ret_s - ret_p).abs().max().item() + + atol = 1e-5 + assert adv_err < atol, f"adv error too large: {adv_err}" + assert ret_err < atol, f"ret error too large: {ret_err}" + + # ---------- logging ---------- + print(f"\n[GAE Test] B={B}, T={T}, chunk={chunk_size}") + print(f" Serial : {serial_time:.6f} s") + print(f" Parallel : {parallel_time:.6f} s") + print(f" Speedup : x{serial_time / parallel_time:.2f}") + print(f" Max diff adv={adv_err:.3e}, ret={ret_err:.3e}") diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..9b6e69c29 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,6 +126,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_fsdp_import.py b/tests/test_fsdp_import.py new file mode 100644 index 000000000..66b6861ed --- /dev/null +++ b/tests/test_fsdp_import.py @@ -0,0 +1,9 @@ +import pytest + + +def test_fsdp_import(): + try: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + except ImportError: + pytest.skip("FSDP not available in this environment") + assert FSDP is not None diff --git a/tests/test_fused_experts_backward.py b/tests/test_fused_experts_backward.py new file mode 100644 index 000000000..e2a94897b --- /dev/null +++ b/tests/test_fused_experts_backward.py @@ -0,0 +1,593 @@ +""" +Test script to compare Triton backward implementation with Python backward implementation. + +This test compares: +1. Triton implementation (from fused_experts.py) - uses invoke_fused_moe_backward_kernel +2. 
Python reference implementation (defined in this file) - uses pure PyTorch operations +""" + +import pytest +import torch + +# ============================================================================ +# Python Reference Implementation (Pure PyTorch) +# ============================================================================ + + +class GateUpProjFunctionPython(torch.autograd.Function): + @staticmethod + def forward( + ctx, + hidden_states: torch.Tensor, + w1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + num_tokens, D_in = hidden_states.shape + E, N, K = w1.shape + assert D_in == K, f"hidden_states dim {D_in} != w1 dim {K}" + + topk = topk_ids.shape[1] + + # Output: (num_tokens * topk, N) + intermediate_cache1 = torch.empty( + (num_tokens * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + # Python implementation: iterate over tokens and their topk experts + # For each token t and expert k: + # intermediate_cache1[t*topk + k] = hidden_states[t] @ w1[expert_id].T + for t in range(num_tokens): + for k in range(topk): + expert_id = topk_ids[t, k].item() + x_t = hidden_states[t] # shape: (D_in,) + W1_e = w1[expert_id] # shape: (N, K) + intermediate_cache1[t * topk + k] = x_t @ W1_e.T + + ctx.save_for_backward(hidden_states, w1, topk_weights, topk_ids) + ctx.num_tokens = num_tokens + ctx.topk = topk + + return intermediate_cache1 + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for GateUpProjFunction - Pure Python implementation. + + Forward: output = input @ w1 (without topk_weight multiplication) + Backward: + - grad_hidden_states = grad_output @ w1 + - grad_w1 = grad_output.T @ input (note: transposed) + - grad_topk_weights = zeros (not needed in this stage) + + Args: + grad_output: shape (num_tokens * topk, N) + + Returns: + (grad_hidden_states, grad_w1, grad_topk_weights, None) + """ + hidden_states, w1, topk_weights, topk_ids = ctx.saved_tensors + topk = ctx.topk + + num_tokens, D_in = hidden_states.shape + E, N, _ = w1.shape + CHUNK_SIZE = 64 * 1024 + + # Initialize gradient tensors + grad_hidden_states = torch.zeros_like(hidden_states) + # Use float32 for grad_w1 accumulation to avoid bfloat16 precision loss + grad_w1 = torch.zeros(w1.shape, dtype=torch.float32, device=w1.device) + # GateUpProj stage doesn't compute topk_weights gradient + grad_topk_weights = torch.zeros_like(topk_weights) + + # Process in chunks to match forward pass + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx * topk : end_chunk_idx * topk] + + # 1. Calculate grad_hidden_states: grad_output @ w1 + # For each token t and expert k: + # grad_hidden_states[t] += grad_output[t*topk+k] @ w1[expert_id] + for t in range(curr_num_tokens): + for k in range(topk): + expert_id = curr_topk_ids[t, k].item() + grad_y_tk = curr_grad_output[t * topk + k] # shape: (N,) + W1_e = w1[expert_id] # shape: (N, D_in) + # grad_x: (N,) @ (N, D_in) -> (D_in,) + grad_hidden_states[begin_chunk_idx + t] += grad_y_tk @ W1_e + + # 2. 
Calculate grad_w1: input.T @ grad_output + # For each token t and expert k: + # grad_w1[expert_id] += input[t].T @ grad_output[t*topk+k] + # Which is: grad_w1[expert_id] += outer(grad_output[t*topk+k], input[t]) + for t in range(curr_num_tokens): + for k in range(topk): + expert_id = curr_topk_ids[t, k].item() + x_t = curr_hidden_states[t] # shape: (D_in,) + grad_y_tk = curr_grad_output[t * topk + k] # shape: (N,) + # grad_W1: outer(grad_y_tk, x_t) -> (N, D_in) + # Accumulate in float32 + grad_w1[expert_id] += torch.outer(grad_y_tk, x_t).to(torch.float32) + + # Convert grad_w1 back to original dtype (bfloat16) + grad_w1 = grad_w1.to(hidden_states.dtype) + + return grad_hidden_states, grad_w1, grad_topk_weights, None + + +class DownProjFunctionPython(torch.autograd.Function): + @staticmethod + def forward( + ctx, + intermediate_cache2: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + total_tokens, intermediate_size = intermediate_cache2.shape + topk = topk_ids.shape[1] + num_tokens = total_tokens // topk + E, hidden_size, K = w2.shape + assert intermediate_size == K, f"intermediate_cache2 dim {intermediate_size} != w2 dim {K}" + + # Output: (num_tokens, topk, hidden_size) + intermediate_cache3 = torch.empty( + (num_tokens, topk, hidden_size), + device=intermediate_cache2.device, + dtype=intermediate_cache2.dtype, + ) + + # Python implementation: iterate over tokens and their topk experts + # For each token t and expert k: + # intermediate_cache3[t, k] = topk_weights[t, k] * (intermediate_cache2[t*topk+k] @ w2[expert_id].T) + for t in range(num_tokens): + for k in range(topk): + expert_id = topk_ids[t, k].item() + x_tk = intermediate_cache2[t * topk + k] # shape: (intermediate_size,) + W2_e = w2[expert_id] # shape: (hidden_size, intermediate_size) + weight_tk = topk_weights[t, k] # scalar + + intermediate_cache3[t, k] = weight_tk * (x_tk @ W2_e.T) + + ctx.save_for_backward(intermediate_cache2, w2, topk_weights, topk_ids) + ctx.num_tokens = num_tokens + ctx.topk = topk + + return intermediate_cache3 + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for DownProjFunction - Pure Python implementation. 
+ + Forward: output = topk_weights * (input @ w2) (with topk_weight multiplication) + Backward: + - grad_intermediate_cache2 = topk_weights * (grad_output @ w2) + - grad_w2 = topk_weights * (grad_output.T @ intermediate_cache2) + - grad_topk_weights = dot(grad_output, forward_output_before_weighting) + + Args: + grad_output: shape (num_tokens, topk, hidden_size) + + Returns: + (grad_intermediate_cache2, grad_w2, grad_topk_weights, None) + """ + intermediate_cache2, w2, topk_weights, topk_ids = ctx.saved_tensors + num_tokens = ctx.num_tokens + topk = ctx.topk + + E, hidden_size, intermediate_size = w2.shape + CHUNK_SIZE = 64 * 1024 + + # Initialize gradient tensors + grad_intermediate_cache2 = torch.zeros_like(intermediate_cache2) + # Use float32 for grad_w2 accumulation to avoid bfloat16 precision loss + grad_w2 = torch.zeros(w2.shape, dtype=torch.float32, device=w2.device) + # Compute grad_topk_weights in DownProjFunction backward + grad_topk_weights = torch.zeros_like(topk_weights) + + # Process in chunks to match forward pass + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_intermediate_cache2 = intermediate_cache2[begin_chunk_idx * topk : end_chunk_idx * topk] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + # 1. Calculate grad_intermediate_cache2: topk_weights * (grad_output @ w2) + for t in range(curr_num_tokens): + for k in range(topk): + expert_id = curr_topk_ids[t, k].item() + grad_y_tk = curr_grad_output[t, k] # shape: (hidden_size,) + W2_e = w2[expert_id] # shape: (hidden_size, intermediate_size) + weight_tk = curr_topk_weights[t, k] # scalar + + grad_intermediate_cache2[(begin_chunk_idx + t) * topk + k] = weight_tk * (grad_y_tk @ W2_e) + + # 2. Calculate grad_w2: topk_weights * (grad_output.T @ intermediate_cache2) + for t in range(curr_num_tokens): + for k in range(topk): + expert_id = curr_topk_ids[t, k].item() + grad_y_tk = curr_grad_output[t, k] # shape: (hidden_size,) + x_tk = curr_intermediate_cache2[t * topk + k] # shape: (intermediate_size,) + weight_tk = curr_topk_weights[t, k] # scalar + + # Accumulate in float32 + grad_w2[expert_id] += (weight_tk * torch.outer(grad_y_tk, x_tk)).to(torch.float32) + + # 3. 
Calculate grad_topk_weights: dot(grad_output, forward_output_before_weighting) + for t in range(curr_num_tokens): + for k in range(topk): + expert_id = curr_topk_ids[t, k].item() + grad_y_tk = curr_grad_output[t, k] # shape: (hidden_size,) + x_tk = curr_intermediate_cache2[t * topk + k] # shape: (intermediate_size,) + W2_e = w2[expert_id] # shape: (hidden_size, intermediate_size) + + # Compute forward output before weighting + forward_output_unweighted = x_tk @ W2_e.T # shape: (hidden_size,) + + # grad_topk_weights: dot product + grad_topk_weights[begin_chunk_idx + t, k] += torch.sum(grad_y_tk * forward_output_unweighted) + + # Convert grad_w2 back to original dtype (bfloat16) + grad_w2 = grad_w2.to(intermediate_cache2.dtype) + + return grad_intermediate_cache2, grad_w2, grad_topk_weights, None + + +# ============================================================================ +# Import Triton Implementation +# ============================================================================ + +from miles.backends.fsdp_utils.kernels.fused_experts import DownProjFunction as DownProjFunctionTriton +from miles.backends.fsdp_utils.kernels.fused_experts import GateUpProjFunction as GateUpProjFunctionTriton + +# ============================================================================ +# Test Fixtures and Utilities +# ============================================================================ + + +@pytest.fixture +def setup_moe_params(): + """Setup MOE parameters for testing.""" + torch.manual_seed(42) + + # Small parameters for easier debugging + num_tokens = 64 + hidden_size = 128 + intermediate_size = 256 + num_experts = 4 + topk = 2 + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + # Create input tensors with random values for better testing + hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype) + + # Create expert weights + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size, device=device, dtype=dtype) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device, dtype=dtype) + + # Create router outputs + topk_weights = torch.rand(num_tokens, topk, device=device, dtype=dtype) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) # normalize + + # Random expert selection + topk_ids = torch.stack([torch.randperm(num_experts, device=device)[:topk] for _ in range(num_tokens)], dim=0).to( + torch.int32 + ) + + return { + "hidden_states": hidden_states, + "w1": w1, + "w2": w2, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "device": device, + "dtype": dtype, + } + + +# ============================================================================ +# Test Cases +# ============================================================================ + + +class TestGateUpProjBackward: + """Test GateUpProjFunction backward pass comparison.""" + + def test_forward_consistency(self, setup_moe_params): + """Test that Triton and Python implementations produce same forward output.""" + params = setup_moe_params + + # Python implementation + out_python = GateUpProjFunctionPython.apply( + params["hidden_states"].clone(), + params["w1"].clone(), + params["topk_weights"].clone(), + params["topk_ids"].clone(), + ) + + # Triton implementation + out_triton = GateUpProjFunctionTriton.apply( + params["hidden_states"].clone(), + params["w1"].clone(), + params["topk_weights"].clone(), + params["topk_ids"].clone(), + ) + + # Check outputs are close + torch.testing.assert_close(out_python, out_triton, 
rtol=1, atol=1) + print("โœ“ GateUpProjFunction forward test passed") + + def test_backward_consistency(self, setup_moe_params): + """Test that Triton and Python implementations produce same gradients.""" + params = setup_moe_params + + # Prepare inputs with requires_grad + hidden_states_python = params["hidden_states"].clone().requires_grad_(True) + w1_python = params["w1"].clone().requires_grad_(True) + topk_weights_python = params["topk_weights"].clone().requires_grad_(True) + topk_ids_python = params["topk_ids"].clone() + + hidden_states_triton = params["hidden_states"].clone().requires_grad_(True) + w1_triton = params["w1"].clone().requires_grad_(True) + topk_weights_triton = params["topk_weights"].clone().requires_grad_(True) + topk_ids_triton = params["topk_ids"].clone() + + # Python implementation + out_python = GateUpProjFunctionPython.apply( + hidden_states_python, + w1_python, + topk_weights_python, + topk_ids_python, + ) + + # Triton implementation + out_triton = GateUpProjFunctionTriton.apply( + hidden_states_triton, + w1_triton, + topk_weights_triton, + topk_ids_triton, + ) + + # Create gradient for backward + grad_output = torch.randn_like(out_python) + + # Backward pass + out_python.backward(grad_output.clone()) + out_triton.backward(grad_output.clone()) + + # Check hidden_states gradients + print("\n" + "=" * 80) + print("GateUpProjFunction Backward - hidden_states gradients:") + print("=" * 80) + if hidden_states_python.grad is not None and hidden_states_triton.grad is not None: + diff = hidden_states_python.grad - hidden_states_triton.grad + max_diff = torch.max(torch.abs(diff)) + print(f"Max absolute difference: {max_diff:.6f}") + torch.testing.assert_close(hidden_states_python.grad, hidden_states_triton.grad, rtol=1, atol=1) + print("โœ“ hidden_states gradient matches") + print("=" * 80 + "\n") + + # Check w1 gradients + print("\n" + "=" * 80) + print("GateUpProjFunction Backward - w1 gradients:") + print("=" * 80) + if w1_python.grad is not None and w1_triton.grad is not None: + diff = w1_python.grad - w1_triton.grad + max_diff = torch.max(torch.abs(diff)) + print(f"Max absolute difference: {max_diff:.6f}") + torch.testing.assert_close(w1_python.grad, w1_triton.grad, rtol=1, atol=1) + print("โœ“ w1 gradient matches") + print("=" * 80 + "\n") + + print("โœ“ GateUpProjFunction backward test passed") + + +class TestDownProjBackward: + """Test DownProjFunction backward pass comparison.""" + + def test_forward_consistency(self, setup_moe_params): + """Test that Triton and Python implementations produce same forward output.""" + params = setup_moe_params + + # Create intermediate input (after SiluAndMul) + num_tokens = params["hidden_states"].shape[0] + topk = params["topk_ids"].shape[1] + intermediate_size = params["w2"].shape[2] + intermediate_cache2 = torch.randn( + num_tokens * topk, intermediate_size, device=params["device"], dtype=params["dtype"] + ) + + # Python implementation + out_python = DownProjFunctionPython.apply( + intermediate_cache2.clone(), + params["w2"].clone(), + params["topk_weights"].clone(), + params["topk_ids"].clone(), + ) + + # Triton implementation + out_triton = DownProjFunctionTriton.apply( + intermediate_cache2.clone(), + params["w2"].clone(), + params["topk_weights"].clone(), + params["topk_ids"].clone(), + ) + + # Check outputs are close + torch.testing.assert_close(out_python, out_triton, rtol=1, atol=1) + print("โœ“ DownProjFunction forward test passed") + + def test_backward_consistency(self, setup_moe_params): + """Test that Triton and 
Python implementations produce same gradients.""" + params = setup_moe_params + + # Create intermediate input + num_tokens = params["hidden_states"].shape[0] + topk = params["topk_ids"].shape[1] + intermediate_size = params["w2"].shape[2] + + intermediate_cache2_base = torch.randn( + num_tokens * topk, intermediate_size, device=params["device"], dtype=params["dtype"] + ) + + intermediate_cache2_python = intermediate_cache2_base.clone().requires_grad_(True) + intermediate_cache2_triton = intermediate_cache2_base.clone().requires_grad_(True) + + w2_python = params["w2"].clone().requires_grad_(True) + w2_triton = params["w2"].clone().requires_grad_(True) + + topk_weights_python = params["topk_weights"].clone().requires_grad_(True) + topk_weights_triton = params["topk_weights"].clone().requires_grad_(True) + + # Python implementation + out_python = DownProjFunctionPython.apply( + intermediate_cache2_python, + w2_python, + topk_weights_python, + params["topk_ids"], + ) + + # Triton implementation + out_triton = DownProjFunctionTriton.apply( + intermediate_cache2_triton, + w2_triton, + topk_weights_triton, + params["topk_ids"], + ) + + # Create gradient for backward + grad_output = torch.randn_like(out_python) + + # Backward pass + out_python.backward(grad_output.clone()) + out_triton.backward(grad_output.clone()) + + # Check intermediate_cache2 gradients + print("\n" + "=" * 80) + print("DownProjFunction Backward - intermediate_cache2 gradients:") + print("=" * 80) + if intermediate_cache2_python.grad is not None and intermediate_cache2_triton.grad is not None: + diff = intermediate_cache2_python.grad - intermediate_cache2_triton.grad + max_diff = torch.max(torch.abs(diff)) + print(f"Max absolute difference: {max_diff:.6f}") + torch.testing.assert_close( + intermediate_cache2_python.grad, intermediate_cache2_triton.grad, rtol=1, atol=1 + ) + print("โœ“ intermediate_cache2 gradient matches") + print("=" * 80 + "\n") + + # Check topk_weights gradients + print("\n" + "=" * 80) + print("DownProjFunction Backward - topk_weights gradients:") + print("=" * 80) + if topk_weights_python.grad is not None and topk_weights_triton.grad is not None: + diff = topk_weights_python.grad - topk_weights_triton.grad + max_diff = torch.max(torch.abs(diff)) + print(f"Max absolute difference: {max_diff:.6f}") + torch.testing.assert_close(topk_weights_python.grad, topk_weights_triton.grad, rtol=1, atol=1) + print("โœ“ topk_weights gradient matches") + print("=" * 80 + "\n") + + # Check w2 gradients + print("\n" + "=" * 80) + print("DownProjFunction Backward - w2 gradients:") + print("=" * 80) + if w2_python.grad is not None and w2_triton.grad is not None: + diff = w2_python.grad - w2_triton.grad + max_diff = torch.max(torch.abs(diff)) + print(f"Max absolute difference: {max_diff:.6f}") + torch.testing.assert_close(w2_python.grad, w2_triton.grad, rtol=1, atol=1) + print("โœ“ w2 gradient matches") + print("=" * 80 + "\n") + + print("โœ“ DownProjFunction backward test passed") + + +# ============================================================================ +# Main Test Runner +# ============================================================================ + + +def run_all_tests(): + """Run all tests.""" + print("=" * 80) + print("Running Fused Experts Backward Tests") + print("Testing: Triton Implementation vs Python Reference") + print("=" * 80) + + if not torch.cuda.is_available(): + print("WARNING: CUDA not available, skipping tests") + return + + # Setup parameters + torch.manual_seed(42) + params_dict = {} + + # 
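# Illustrative sketch (not part of this diff): the pure-PyTorch reference above can also be
# checked against finite differences with torch.autograd.gradcheck on tiny float64 inputs.
# Assumes GateUpProjFunctionPython (defined earlier in this file) is in scope; tolerances
# may need loosening because grad_w1 is accumulated in float32.
import torch


def gradcheck_gate_up_reference() -> bool:
    torch.manual_seed(0)
    num_tokens, hidden, inter, experts, topk = 2, 4, 6, 3, 2
    x = torch.randn(num_tokens, hidden, dtype=torch.float64, requires_grad=True)
    w1 = torch.randn(experts, inter * 2, hidden, dtype=torch.float64, requires_grad=True)
    topk_weights = torch.rand(num_tokens, topk, dtype=torch.float64, requires_grad=True)
    topk_ids = torch.stack([torch.randperm(experts)[:topk] for _ in range(num_tokens)]).to(torch.int32)
    # gradcheck compares the analytic backward against numerical gradients element-wise.
    return torch.autograd.gradcheck(
        lambda a, b, c: GateUpProjFunctionPython.apply(a, b, c, topk_ids),
        (x, w1, topk_weights),
        eps=1e-6,
        atol=1e-4,
    )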
Small parameters for testing + num_tokens = 64 + hidden_size = 128 + intermediate_size = 256 + num_experts = 4 + topk = 2 + + device = "cuda" + dtype = torch.bfloat16 + + # Create input tensors + params_dict["hidden_states"] = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype) + params_dict["w1"] = torch.randn(num_experts, intermediate_size * 2, hidden_size, device=device, dtype=dtype) + params_dict["w2"] = torch.randn(num_experts, hidden_size, intermediate_size, device=device, dtype=dtype) + params_dict["topk_weights"] = torch.rand(num_tokens, topk, device=device, dtype=dtype) + params_dict["topk_weights"] = params_dict["topk_weights"] / params_dict["topk_weights"].sum(dim=-1, keepdim=True) + params_dict["topk_ids"] = torch.stack( + [torch.randperm(num_experts, device=device)[:topk] for _ in range(num_tokens)], dim=0 + ).to(torch.int32) + params_dict["device"] = device + params_dict["dtype"] = dtype + + print("\n" + "=" * 80) + print("Testing GateUpProjFunction Backward") + print("=" * 80) + test_gate_up = TestGateUpProjBackward() + test_gate_up.test_forward_consistency(params_dict) + test_gate_up.test_backward_consistency(params_dict) + + print("\n" + "=" * 80) + print("Testing DownProjFunction Backward") + print("=" * 80) + test_down = TestDownProjBackward() + test_down.test_forward_consistency(params_dict) + test_down.test_backward_consistency(params_dict) + + print("\n" + "=" * 80) + print("All Backward Tests Passed! โœ“") + print("=" * 80) + + +if __name__ == "__main__": + run_all_tests() diff --git a/tests/test_gspo.sh b/tests/test_gspo.sh new file mode 100644 index 000000000..6e915ca65 --- /dev/null +++ b/tests/test_gspo.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-0.6B +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 2 + --rollout-batch-size 4 + --n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 16 +) + +GSPO_ARGS=( + --advantage-estimator gspo + #--use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 3.5e-4 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 +) + +# launch the master node of ray in container +ray start --head --node-ip-address 127.0.0.1 --num-gpus 4 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "no_proxy": "localhost,127.0.0.1,0.0.0.0,${MASTER_ADDR}" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + --train-backend fsdp \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GSPO_ARGS[@]} \ + ${SGLANG_ARGS[@]} diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py new file mode 100644 index 000000000..d90a2d7a7 --- /dev/null +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -0,0 +1,147 @@ +"""End-to-end test for MTP-only gradient verification. 
+ +This test verifies that when MTP training is enabled and all outputs are truncated +(due to very short max response length), only MTP parameters receive non-zero +gradients while all other model parameters have zero gradients. + +This validates that the MTP loss computation correctly isolates gradient flow +to only the MTP layers when the main model loss is zero (due to truncation). +""" + +import os + +import miles.utils.external_utils.command_utils as U + + +MODEL_NAME = "MiMo-7B-RL" +MODEL_TYPE = "mimo-7B-rl" +NUM_GPUS = 8 + + +def prepare(): + """Download model and convert checkpoint with MTP layers.""" + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download XiaomiMiMo/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + # Convert checkpoint with MTP layers enabled + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + extra_args=" --mtp-num-layers 1", + dir_dst="/root/models", + ) + + +def execute(): + """Run training with MTP enabled and very short output length to cause truncation.""" + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + # Use very short rollout-max-response-len to ensure all outputs are truncated + # This should result in zero loss for the main model, leaving only MTP loss + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 2 " + # Very short max response length to cause all outputs to be truncated + "--rollout-max-response-len 128 " + "--rollout-temperature 0.8 " + "--global-batch-size 8 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-enable-metrics " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + # Enable MTP training with loss scaling + mtp_args = "--mtp-num-layers 1 " "--enable-mtp-training " "--mtp-loss-scaling-factor 0.2 " + + ci_args = ( + "--ci-test " + "--ci-disable-kl-checker " + # MTP grad check is automatically triggered when ci_test and enable_mtp_training are both set + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + 
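# Illustrative sketch (not part of this diff): the property this test relies on, namely that
# with every rollout truncated only MTP parameters should receive gradient. The real check
# runs inside miles when --ci-test and --enable-mtp-training are both set; the name-based
# predicate below is an assumption for illustration only.
import torch


def assert_only_mtp_grads(model: torch.nn.Module, is_mtp_param=lambda name: "mtp" in name) -> None:
    for name, param in model.named_parameters():
        grad_norm = 0.0 if param.grad is None else param.grad.float().norm().item()
        if is_mtp_param(name):
            assert grad_norm > 0.0, f"expected non-zero MTP gradient for {name}"
        else:
            assert grad_norm == 0.0, f"expected zero gradient for non-MTP parameter {name}"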
f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + # Remove proxy settings that might interfere with local operations + for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]: + os.environ.pop(key, None) + execute() diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py new file mode 100644 index 000000000..c35943ec1 --- /dev/null +++ b/tests/test_moonlight_16B_A3B.py @@ -0,0 +1,124 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Moonlight-16B-A3B-Instruct" +MODEL_TYPE = "moonlight" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command( + "hf download moonshotai/Moonlight-16B-A3B-Instruct --local-dir /root/models/Moonlight-16B-A3B-Instruct" + ) + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 2048} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " "--sglang-mem-fraction-static 0.8 " "--sglang-max-running-requests 512 " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + 
f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_moonlight_16B_A3B_r3.py b/tests/test_moonlight_16B_A3B_r3.py new file mode 100644 index 000000000..cdb898c19 --- /dev/null +++ b/tests/test_moonlight_16B_A3B_r3.py @@ -0,0 +1,125 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Moonlight-16B-A3B-Instruct" +MODEL_TYPE = "moonlight" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command( + "hf download moonshotai/Moonlight-16B-A3B-Instruct --local-dir /root/models/Moonlight-16B-A3B-Instruct" + ) + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 2048} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--use-rollout-routing-replay " + "--use-miles-router " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " "--sglang-mem-fraction-static 0.8 " "--sglang-max-running-requests 512 " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + 
f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index f18888c22..ae3c383ae 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -1,3 +1,4 @@ +import os import miles.utils.external_utils.command_utils as U ENABLE_EVAL = U.get_bool_env_var("MILES_TEST_ENABLE_EVAL", "1") @@ -114,10 +115,13 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": # TODO also use typer prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 6302aadb6..4d7f034f6 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -1,3 +1,4 @@ +import os import miles.utils.external_utils.command_utils as U @@ -119,9 +120,12 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 1c55ccb20..32b60f593 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -1,3 +1,4 @@ +import os import miles.utils.external_utils.command_utils as U FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") @@ -119,9 +120,12 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py new file mode 100644 index 000000000..b1954a4e8 --- /dev/null +++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -0,0 +1,129 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + 
"--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 8 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.55 if TIGHT_DEVICE_MEMORY else 0.65} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 1 " + "--rollout-num-gpus 3 " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py new file mode 100644 index 000000000..86e21eac8 --- /dev/null +++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py @@ -0,0 +1,128 @@ +import os +import miles.utils.external_utils.command_utils as U + +TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") + +MODEL_NAME = "Qwen2.5-0.5B-Instruct" +MODEL_TYPE = "qwen2.5-0.5B" +NUM_GPUS = 4 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}/ " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + 
"--over-sampling-batch-size 16 " + "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " + "--global-batch-size 32 " + ) + + eval_args = ( + "--eval-interval 20 " + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + fault_tolerance_args = ( + "--use-fault-tolerance " + "--rollout-health-check-interval 5 " + "--rollout-health-check-timeout 10 " + "--rollout-health-check-first-wait 0 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + "--megatron-to-hf-mode bridge " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{fault_tolerance_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 6967f9145..3d4768e42 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,3 +1,4 @@ +import os import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" @@ -92,9 +93,12 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index b3eb416b3..fcd777288 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,3 +1,4 @@ +import os import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" @@ -94,9 +95,12 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": prepare() + for 
proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py new file mode 100644 index 000000000..b89a2f283 --- /dev/null +++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py @@ -0,0 +1,155 @@ +import os + +import miles.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 4 +CP_SIZE = 1 +MEGATRON_TP_SIZE = 1 +MEGATRON_PP_SIZE = 1 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/" + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 64 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " "--sglang-chunked-prefill-size 4096 " "--sglang-mem-fraction-static 0.75 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " "--colocate " f"--actor-num-gpus-per-node {NUM_GPUS} " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + debug_data_path = "test_rollout_data_megatron_fsdp_align.pt" + grad_norm_path = "grad_norm_fsdp.pt" + + fsdp_args = ( + "--train-backend fsdp " + "--attn-implementation flash_attention_2 " + "--gradient-checkpointing " + f"--context-parallel-size {CP_SIZE} " + f"--update-weight-buffer-size {512 * 1024 * 1024} " + """--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}' """ + ) + + try: + U.execute_train( + train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"{fsdp_args}" + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-save-grad-norm {grad_norm_path} " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + U.execute_train( + train_args=train_args + + ( + f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + f"--tensor-model-parallel-size {MEGATRON_TP_SIZE} " + "--sequence-parallel " + f"--pipeline-model-parallel-size {MEGATRON_PP_SIZE} " + f"--context-parallel-size {CP_SIZE} " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + 
"--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--train-memory-margin-bytes 3221225472 " + f"--load-debug-rollout-data {debug_data_path} " + f"--ci-load-grad-norm {grad_norm_path} " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--debug-train-only " + ), + num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + megatron_model_type=MODEL_TYPE, + ) + + finally: + if os.path.exists(grad_norm_path): + os.remove(grad_norm_path) + if os.path.exists(debug_data_path): + os.remove(debug_data_path) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py new file mode 100644 index 000000000..d0ad283d1 --- /dev/null +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -0,0 +1,138 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 1 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--rollout-num-gpus 8 " "--sglang-mem-fraction-static 0.8 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + for i in range(2): + U.execute_train( + train_args=train_args + + ( + f"--save-debug-rollout-data data-{i}.pt " + f"--ci-save-grad-norm grad_norms-{i}.pt " + f"--actor-num-gpus-per-node {NUM_GPUS} " + ), + num_gpus_per_node=NUM_GPUS, + 
megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + # Sweep TP/PP/CP combinations on 8, 4, and 2 GPUs and check each against the saved reference. + for num_gpus in [8, 4, 2]: + for tp_size in [1, 2, 4, 8]: + for pp_size in [1, 2, 4]: + for cp_size in [1, 2, 4, 8]: + # Skip combinations that need more GPUs than are available. + if tp_size * pp_size * cp_size > num_gpus: + continue + args = train_args + ( + f"--load-debug-rollout-data data-{i}.pt " + f"--ci-load-grad-norm grad_norms-{i}.pt " + f"--context-parallel-size {cp_size} " + f"--tensor-model-parallel-size {tp_size} " + f"--pipeline-model-parallel-size {pp_size} " + "--sequence-parallel " + f"--actor-num-gpus-per-node {num_gpus} " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 8192 " + ) + + U.execute_train( + train_args=args, + num_gpus_per_node=num_gpus, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + train_args += "--calculate-per-token-loss " + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index 6b5f6b889..b30eeed8e 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -5,6 +5,8 @@ ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) +USE_DEEPEP = bool(int(os.environ.get("MILES_TEST_USE_DEEPEP", "1"))) +USE_FP8_ROLLOUT = bool(int(os.environ.get("MILES_TEST_USE_FP8_ROLLOUT", "1"))) MODEL_NAME = "Qwen3-30B-A3B" MODEL_TYPE = "qwen3-30B-A3B" @@ -22,7 +24,10 @@ def prepare(): def execute(): - ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}-FP8 " f"--ref-load /root/{MODEL_NAME}_torch_dist " + if USE_FP8_ROLLOUT: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}-FP8 " f"--ref-load /root/{MODEL_NAME}_torch_dist " + else: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " rollout_args = ( "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " @@ -89,13 +94,13 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " "--sglang-mem-fraction-static 0.8 " - "--sglang-moe-a2a-backend deepep " - "--sglang-deepep-mode auto " "--sglang-max-running-requests 512 " - "--sglang-disable-radix-cache " "--sglang-enable-metrics " ) + if USE_DEEPEP: + sglang_args += "--sglang-moe-a2a-backend deepep --sglang-deepep-mode auto " + ci_args = "--ci-test " misc_args = ( @@ -107,13 +112,16 @@ def execute(): "--attention-softmax-in-fp32 " # need to comment this when using model with MLA "--attention-backend flash " - "--moe-token-dispatcher-type flex " - "--moe-enable-deepep " "--actor-num-nodes 1 " "--actor-num-gpus-per-node 8 " "--colocate " ) + if USE_DEEPEP: + misc_args += "--moe-token-dispatcher-type flex --moe-enable-deepep " + else: + misc_args += "--moe-token-dispatcher-type alltoall " + train_args = ( f"{ckpt_args} " f"{rollout_args} " @@ -131,10 +139,13 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) if __name__ == "__main__": # TODO also use typer prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_30B_A3B_r3.py 
b/tests/test_qwen3_30B_A3B_r3.py new file mode 100644 index 000000000..5a5b968aa --- /dev/null +++ b/tests/test_qwen3_30B_A3B_r3.py @@ -0,0 +1,151 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) +USE_DEEPEP = bool(int(os.environ.get("MILES_TEST_USE_DEEPEP", "1"))) +USE_FP8_ROLLOUT = bool(int(os.environ.get("MILES_TEST_USE_FP8_ROLLOUT", "1"))) + +MODEL_NAME = "Qwen3-30B-A3B" +MODEL_TYPE = "qwen3-30B-A3B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B") + U.exec_command("hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + if USE_FP8_ROLLOUT: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}-FP8 " f"--ref-load /root/{MODEL_NAME}_torch_dist " + else: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + grpo_args = ( + "--advantage-estimator gspo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--use-tis " + "--use-rollout-routing-replay " + "--use-miles-router " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + if USE_DEEPEP: + sglang_args += "--sglang-moe-a2a-backend deepep --sglang-deepep-mode auto " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + 
"--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + if USE_DEEPEP: + misc_args += "--moe-token-dispatcher-type flex --moe-enable-deepep " + else: + misc_args += "--moe-token-dispatcher-type alltoall " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py new file mode 100644 index 000000000..0df4492e1 --- /dev/null +++ b/tests/test_qwen3_4B_ckpt.py @@ -0,0 +1,138 @@ +import os +from argparse import ArgumentParser + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +parser = ArgumentParser() +parser.add_argument("--async-save", action="store_true", help="Whether to test async save/load.") + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"rm -rf /root/models/{MODEL_NAME}_miles") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint( + model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS, dir_dst="/root/models" + ) + + +def execute(mode: str = ""): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + if mode == "save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + elif mode == "async_save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + ckpt_args += "--async-save " + elif mode == "load": + ckpt_args += f"--load /root/models/{MODEL_NAME}_miles " + ckpt_args += "--ckpt-step 1 " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + 
"--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.8 --sglang-cuda-graph-bs 1 2 4 8 16 " + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute("save" if not args.async_save else "async_save") + execute("load") diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py new file mode 100644 index 000000000..03ba4094e --- /dev/null +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -0,0 +1,113 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 2 + +MODEL_NAME = "Qwen3-4B" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + "--eval-top-p 0.7 " + ) + + fsdp_args = "--train-backend fsdp " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-enable-deterministic-inference " + "--sglang-rl-on-policy-target fsdp " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + "--deterministic-mode " + "--true-on-policy-mode " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node 
{NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "NCCL_ALGO": "allreduce:tree", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py new file mode 100644 index 000000000..d4c1ac273 --- /dev/null +++ b/tests/test_qwen3_4B_ppo.py @@ -0,0 +1,134 @@ +import os + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 8 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 8192 " + "--rollout-temperature 0.8 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 16384 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 2 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384} " + ) + + ppo_args = ( + "--advantage-estimator ppo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type k1 " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 4e-4 " + "--num-critic-only-steps 1 " + "--normalize-advantages " + "--critic-lr 1e-5 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--rollout-num-gpus 8 " + "--sglang-mem-fraction-static 0.8 " + "--sglang-max-running-requests 512 " + "--sglang-enable-metrics " + ) + + ci_args = "--ci-test " + + misc_args = ( + # default dropout in megatron is 0.1 + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + # 
should be good for model performance + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + # need to comment this when using model with MLA + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 4 " + "--colocate " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{ppo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, + ) + + +if __name__ == "__main__": + # TODO also use typer + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py new file mode 100644 index 000000000..bc4ef3293 --- /dev/null +++ b/tests/test_qwen3_vl_4B_fsdp.py @@ -0,0 +1,112 @@ +import os +import miles.utils.external_utils.command_utils as U + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +NUM_GPUS = 8 + +MODEL_NAME = "Qwen3-VL-4B-Instruct" +DATASET_NAME = "chenhegu/geo3k_imgurl" + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset(DATASET_NAME) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + + rollout_args = ( + "--prompt-data /root/datasets/geo3k_imgurl/train.parquet " + "--input-key problem " + "--label-key answer " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 8 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 4096 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + ) + + # multimodal keys required for vlm datasets + multimodal_args = '--multimodal-keys \'{"image": "images"}\' ' + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 4096 " + ) + + fsdp_args = "--train-backend fsdp " "--gradient-checkpointing " "--update-weight-buffer-size 536870912 " + + grpo_args = ( + "--advantage-estimator grpo " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 1 " + "--sglang-mem-fraction-static 0.6 " + "--sglang-decode-log-interval 1000 " + "--sglang-enable-metrics " + "--sglang-attention-backend fa3 " + "--attn-implementation flash_attention_3 " + ) + + ci_args = "--ci-test " + + misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{multimodal_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{fsdp_args} " + f"{eval_args} " + f"{sglang_args} " + f"{ci_args} " + f"{misc_args} " + ) + + extra_env_vars = { + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + } + + U.execute_train( + train_args=train_args, + 
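# FSDP training backend (see fsdp_args above); megatron_model_type=None below,
+ # so no Megatron torch_dist conversion is involved in this test. +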
num_gpus_per_node=NUM_GPUS, + megatron_model_type=None, + extra_env_vars=extra_env_vars, + ) + + +if __name__ == "__main__": + prepare() + os.environ.pop("http_proxy", None) + os.environ.pop("https_proxy", None) + os.environ.pop("HTTP_PROXY", None) + os.environ.pop("HTTPS_PROXY", None) + execute() diff --git a/tools/convert_fsdp_to_hf.py b/tools/convert_fsdp_to_hf.py new file mode 100644 index 000000000..142730c81 --- /dev/null +++ b/tools/convert_fsdp_to_hf.py @@ -0,0 +1,178 @@ +import argparse +import os +import pickle +import shutil +import time + +import torch +import torch.distributed.checkpoint as dist_cp +from transformers import AutoConfig, AutoModelForCausalLM +from typing_extensions import override + + +class UnpicklerWrapper(pickle.Unpickler): + @override + def find_class(self, mod_name, name): + class DummyClass: + def __init__(self, *args, **kwargs): + pass + + if mod_name.startswith("megatron") or mod_name.startswith("glm"): + return DummyClass + return super().find_class(mod_name, name) + + +class WrappedStorageReader(dist_cp.FileSystemReader): + @override + def read_metadata(self): + path = self.fs.concat_path(self.path, ".metadata") + with self.fs.create_stream(path, "rb") as metadata_file: + metadata = UnpicklerWrapper(metadata_file).load() + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = dist_cp.StorageMeta() + metadata.storage_meta.load_id = self.load_id + if metadata.planner_data is None: + metadata.planner_data = {} + return metadata + + +class EmptyStateDictLoadPlanner(dist_cp.default_planner.DefaultLoadPlanner): + @override + def set_up_planner( + self, + state_dict: dist_cp.metadata.STATE_DICT_TYPE, + metadata: dist_cp.metadata.Metadata | None = None, + is_coordinator: bool = False, + ) -> None: + for k, v in metadata.state_dict_metadata.items(): + if "optimizer" in k: + continue + print(f"find {k} in torch_dist ckpt") + if isinstance(v, dist_cp.metadata.TensorStorageMetadata): + v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + state_dict[k] = v + super().set_up_planner(state_dict, metadata, is_coordinator) + + +def _detect_model_dir(input_dir: str) -> str: + model_dir = os.path.join(input_dir, "model") + return model_dir if os.path.isdir(model_dir) else input_dir + + +def _load_fsdp_state_dict(input_dir: str) -> dict[str, torch.Tensor]: + state_dict: dict[str, torch.Tensor] = {} + dist_cp.state_dict_loader._load_state_dict( + state_dict, + storage_reader=WrappedStorageReader(input_dir), + planner=EmptyStateDictLoadPlanner(), + no_dist=True, + ) + return state_dict + + +def _get_candidate_prefixes(keys: list[str]) -> list[str]: + predefined = [ + "model_state.model.", + "model_state.", + "model.", + "module.", + "", + ] + + detected: set[str] = set() + for key in keys: + for prefix in predefined: + if prefix and key.startswith(prefix): + detected.add(prefix) + + # Always keep empty string as a fall back option for exact match. + detected.add("") + # Preserve predefined order while keeping only detected prefixes. 
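+ # Example (hypothetical key): "model_state.model.embed_tokens.weight" detects the
+ # prefixes "model_state.model." and "model_state."; _strip_best_prefix then keeps
+ # "model_state.", since stripping it yields "model.embed_tokens.weight", which
+ # matches the Hugging Face state_dict key.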
+ return [p for p in predefined if p in detected] + + +def _strip_best_prefix(keys: list[str], target_keys: set[str]) -> tuple[str, int]: + best_prefix = "" + best_match = -1 + + for prefix in _get_candidate_prefixes(keys): + mapped_keys = {k.removeprefix(prefix) for k in keys} + match_count = len(mapped_keys & target_keys) + if match_count > best_match: + best_match = match_count + best_prefix = prefix + + return best_prefix, best_match + + +def _convert_fsdp_to_hf( + origin_hf_dir: str, + input_dir: str, + output_dir: str, +) -> None: + print(f"loading FSDP model from {input_dir}") + t = time.time() + state_dict = _load_fsdp_state_dict(input_dir) + print(f"FSDP model loaded in {time.time()-t:.2f} sec.") + + tensor_items = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} + + config = AutoConfig.from_pretrained(origin_hf_dir, trust_remote_code=True) + hf_model = AutoModelForCausalLM.from_config(config) + target_keys = set(hf_model.state_dict().keys()) + + best_prefix, best_match = _strip_best_prefix(list(tensor_items.keys()), target_keys) + total_keys = len(tensor_items) + + print(f"Using prefix '{best_prefix}' for key mapping. " f"Matched {best_match}/{total_keys} parameter keys.") + + model_state = {k.removeprefix(best_prefix): v for k, v in tensor_items.items()} + + if not model_state: + raise ValueError( + "No model weights found in checkpoint. " + "Please pass the checkpoint directory (e.g. iter_xxx or iter_xxx/model)." + ) + + missing, unexpected = hf_model.load_state_dict(model_state, strict=False) + print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}") + + os.makedirs(output_dir, exist_ok=True) + hf_model.save_pretrained(output_dir, safe_serialization=True) + print(f"Model weights saved to {output_dir}") + + +def copy_assets(origin_hf_dir: str, output_dir: str) -> None: + for filename in os.listdir(origin_hf_dir): + if filename == "model.safetensors.index.json" or filename.endswith(".safetensors"): + continue + origin_filename = os.path.join(origin_hf_dir, filename) + if not os.path.isfile(origin_filename): + print(f"Skip {filename}, not a file.") + continue + src, dst = origin_filename, os.path.join(output_dir, filename) + print(f"copy from {src} to {dst}") + shutil.copy(src, dst) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-dir", type=str, required=True) + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument( + "--origin-hf-dir", + type=str, + required=True, + help="The original Hugging Face model directory to load config/tokenizer assets.", + ) + parser.add_argument( + "-f", "--force", action="store_true", help="Force overwrite the output directory if it exists." + ) + args = parser.parse_args() + + if os.path.exists(args.output_dir) and not args.force: + raise ValueError(f"Output directory {args.output_dir} already exists. 
Use --force to overwrite it.") + + model_dir = _detect_model_dir(args.input_dir) + _convert_fsdp_to_hf(args.origin_hf_dir, model_dir, args.output_dir) + copy_assets(args.origin_hf_dir, args.output_dir) diff --git a/tools/convert_hf_to_fp8.py b/tools/convert_hf_to_fp8.py index ee48582e2..7754e7dea 100644 --- a/tools/convert_hf_to_fp8.py +++ b/tools/convert_hf_to_fp8.py @@ -65,7 +65,7 @@ def block_fp8(weight, block_size): .to(torch.float8_e4m3fn) ) qweight = qweight[:shape_0, :shape_1].clone().detach() - scale = scale.squeeze() + scale = scale.reshape(n_tiles, k_tiles) return qweight, scale @@ -101,12 +101,15 @@ def __init__(self): self.weight_map = {} self.param_count = 0 self.modules_to_not_convert = [] + self.has_dsa_layers = False def add_result(self, filename, q_weights, module_names): with self.lock: for k, v in q_weights.items(): self.weight_map[k] = filename self.param_count += len(v) + if "indexer" in k: + self.has_dsa_layers = True self.modules_to_not_convert.extend(module_names) @@ -133,6 +136,7 @@ def process_file(input_path, output_path, filename, strategy, block_size, result and "norm" not in key and "lm_head" not in key and "eh_proj" not in key + and "weights_proj" not in key ): qw, s = quant_fp8(weights[key], strategy, block_size) q_weights[key] = qw @@ -181,6 +185,8 @@ def convert_fp8(input_path, output_path, strategy, block_size=None, max_workers= } if block_size: quantization_config["weight_block_size"] = block_size + if result_collector.has_dsa_layers: + quantization_config["scale_fmt"] = "ue8m0" if len(result_collector.modules_to_not_convert) > 0: quantization_config["modules_to_not_convert"] = list(set(result_collector.modules_to_not_convert)) else: diff --git a/tools/convert_hf_to_hf_int4.py b/tools/convert_hf_to_hf_int4.py new file mode 100644 index 000000000..ba76a987f --- /dev/null +++ b/tools/convert_hf_to_hf_int4.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +import argparse +import os +import random + +import torch +from datasets import Dataset, load_dataset +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization.gptq import GPTQModifier +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input-dir", type=str, required=True, help="local BF16 path") + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument("--data-dir", type=str, required=True, help="dataset path") + parser.add_argument("--quant-type", type=str, choices=["W4A16", "W8A16"], default="W4A16") + parser.add_argument("--num-calibration-samples", type=int, default=256, help="sample nums") + parser.add_argument("--max-sequence-length", type=int, default=2048) + parser.add_argument("--dampening-frac", type=float, default=0.01) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--quant-group-size", type=int, default=32, help="GPTQ Group Size") + return parser.parse_args() + + +def get_calibration_dataset(tokenizer, num_samples, seq_len, local_data_path): + + train_file = os.path.join(local_data_path, "train-00000-of-00001.parquet") + + if not os.path.exists(train_file): + print(f"can't find the localpath: {train_file}") + exit(1) + + try: + ds_raw = load_dataset("parquet", data_files={"train": train_file}, split="train") + except Exception as e: + print(f"load Parquet file failed: {e}") + exit(1) + + text_stream = "".join(ds_raw["text"]) + encoded = tokenizer(text_stream, return_tensors="pt").input_ids[0] + + data_list = [] + for 
_ in range(num_samples): + i = random.randint(0, encoded.shape[0] - seq_len - 1) + chunk = encoded[i : i + seq_len] + + data_list.append({"input_ids": chunk.tolist(), "attention_mask": torch.ones_like(chunk).tolist()}) + + ds_hf = Dataset.from_list(data_list) + return ds_hf + + +def main(): + args = parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.input_dir, trust_remote_code=args.trust_remote_code) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ds_hf = get_calibration_dataset( + tokenizer, args.num_calibration_samples, args.max_sequence_length, args.data_dir + ) + + model = AutoModelForCausalLM.from_pretrained( + args.input_dir, + device_map="auto", + torch_dtype=torch.bfloat16, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=True, + ) + + ignore_patterns = [ + "re:.*lm_head.*", + "re:.*norm.*", + "re:.*embed.*", + "re:.*self_attn.*", + "re:.*shared_experts.*", + "re:.*mlp\\.(gate|up|gate_up|down)_proj.*", + ] + + recipe = GPTQModifier( + targets="Linear", + scheme=args.quant_type, + ignore=ignore_patterns, + dampening_frac=args.dampening_frac, + block_size=32, + ) + + oneshot( + model=model, + dataset=ds_hf, # calibration dataset + tokenizer=tokenizer, + recipe=recipe, + output_dir=args.output_dir, + max_seq_length=args.max_sequence_length, + num_calibration_samples=args.num_calibration_samples, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index 10faa6824..d6fddf386 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -11,7 +11,7 @@ import miles_plugins.mbridge # noqa: F401 from mbridge import AutoBridge -from miles.backends.megatron_utils import set_default_megatron_args +from miles.backends.megatron_utils.arguments import set_default_megatron_args from miles.backends.megatron_utils.initialize import init from miles.backends.megatron_utils.model_provider import get_model_provider_func from miles.utils.logging_utils import configure_logger @@ -21,6 +21,12 @@ def add_convertion_args(parser): """Add conversion arguments to the parser""" parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path") + parser.add_argument( + "--megatron-to-hf-mode", + choices=["raw", "bridge"], + default="raw", + help="The method to convert Megatron weights to Hugging Face weights for SGLang.", + ) try: parser.add_argument("--padded-vocab-size", type=int, default=None) except Exception: diff --git a/train.py b/train.py index 9fb480eda..745dcbed6 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,4 @@ import ray -from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS - -try: - from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH -except ImportError: - GPU_MEMORY_TYPE_CUDA_GRAPH = None from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from miles.utils.arguments import parse_args @@ -27,7 +21,7 @@ def train(args): actor_model, critic_model = create_training_models(args, pgs, rollout_manager) if args.offload_rollout: - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + ray.get(rollout_manager.onload_weights.remote()) # always update weight first so that sglang has the loaded weights from training. 
actor_model.update_weights() @@ -36,9 +30,7 @@ def train(args): ray.get(rollout_manager.check_weights.remote(action="compare")) if args.offload_rollout: - if GPU_MEMORY_TYPE_CUDA_GRAPH is not None: - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])) - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE])) + ray.get(rollout_manager.onload_kv.remote()) # special case for eval-only if args.num_rollout == 0 and args.eval_interval is not None: @@ -55,14 +47,24 @@ def offload_train(): else: actor_model.clear_memory() - def onload_rollout(): - if args.offload_rollout: - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) + def save(rollout_id): + if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps): + actor_model.save_model( + rollout_id, + force_sync=rollout_id == args.num_rollout - 1, + ) + if args.use_critic: + critic_model.save_model( + rollout_id, + force_sync=rollout_id == args.num_rollout - 1, + ) + if args.rollout_global_dataset: + ray.get(rollout_manager.save.remote(rollout_id)) # train loop. # note that for async training, one can change the position of the sync operation(ray.get). for rollout_id in range(args.start_rollout_id, args.num_rollout): - if args.eval_interval is not None and rollout_id == 0: + if args.eval_interval is not None and rollout_id == 0 and not args.skip_eval_before_train: ray.get(rollout_manager.eval.remote(rollout_id)) rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) @@ -78,22 +80,15 @@ def onload_rollout(): else: ray.get(actor_model.async_train(rollout_id, rollout_data_ref)) - if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch): - if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps): - actor_model.save_model(rollout_id) - if args.use_critic: - critic_model.save_model(rollout_id) - if args.rollout_global_dataset: - ray.get(rollout_manager.save.remote(rollout_id)) + if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): + save(rollout_id) offload_train() - onload_rollout() + if args.offload_rollout: + ray.get(rollout_manager.onload_weights.remote()) actor_model.update_weights() - if args.offload_rollout: - if GPU_MEMORY_TYPE_CUDA_GRAPH is not None: - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])) - ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE])) + ray.get(rollout_manager.onload_kv.remote()) if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): ray.get(rollout_manager.eval.remote(rollout_id)) diff --git a/train_async.py b/train_async.py index a43464aaf..bef1d98ab 100644 --- a/train_async.py +++ b/train_async.py @@ -47,10 +47,16 @@ def train(args): else: ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) - if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch): - actor_model.save_model(rollout_id) + if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): + actor_model.save_model( + rollout_id, + force_sync=rollout_id == args.num_rollout - 1, + ) if args.use_critic: - critic_model.save_model(rollout_id) + critic_model.save_model( + rollout_id, + force_sync=rollout_id == args.num_rollout - 1, + ) if args.rollout_global_dataset: ray.get(rollout_manager.save.remote(rollout_id))
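
Note on the save logic above: both train.py and train_async.py now pass args.num_rollout as a fourth argument to should_run_periodic_action at the save call sites, while the eval call sites still pass three arguments, and saves on the final rollout use force_sync. The helper itself is not part of this diff; the sketch below only illustrates the assumed semantics (run every interval steps, at epoch boundaries, and always on the final step when the total is supplied) and may differ from the real implementation in miles.

def should_run_periodic_action(rollout_id, interval, num_rollout_per_epoch, num_rollout=None):
    # Sketch only: assumed semantics, not the actual miles helper.
    if interval is None:
        return False
    if num_rollout is not None and rollout_id == num_rollout - 1:
        return True  # always run on the final rollout, e.g. to write a last checkpoint
    if (rollout_id + 1) % interval == 0:
        return True
    # also run at epoch boundaries when the dataset defines an epoch length
    return num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0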