diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 000000000..66015ad30 --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,10 @@ +have_fun: false +code_review: + disable: false + comment_severity_threshold: HIGH + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: true +ignore_patterns: [] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..06f8e0c3d --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,18 @@ +/docs @eric-haibin-lin @zhaochenyang20 @hongpeng-guo +/docs/amd_tutorial @yushengsu-thu +/docs/slang_multiturn @zhaochenyang20 @SwordFaith + +/recipe/dapo @tongyx361 @PeterSH6 +/recipe/spin @zhaochenyang20 +/recipe/sppo @zhaochenyang20 + +/third_party/sglang @zhaochenyang20 @SwordFaith +/third_party/vllm @PeterSH6 @wuxibin89 +/verl/single_controller @zw0610 @wuxibin89 @hongpeng-guo +/verl/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6 +/verl/workers/rollout/vllm_rollout @wuxibin89 @PeterSH6 @chenhaiq +/verl/workers/rollout/sglang_rollout @zhaochenyang20 @SwordFaith @chenhaiq + +/tests/single_controller @zw0610 @wuxibin89 +/tests/trainer @eric-haibin-lin @vermouth1992 @tongyx361 @PeterSH6 +/tests/workers/rollout/vllm_rollout @wuxibin89 @PeterSH6 @chenhaiq diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 6306a82d1..96f6641cc 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,46 +1,40 @@ -### Checklist Before Starting - -- [ ] Search for similar PR(s). - ### What does this PR do? -> Add one-line overview of what this PR aims to achieve or accomplish. +> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. -### High-Level Design - -> Demonstrate the high-level design if this PR is complex. - -### Specific Changes +### Checklist Before Starting -> List the specific changes. +- [ ] Search for similar PRs. Paste at least one query link here: ... +- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) + - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` + - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` + - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` + - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. + - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` -### API +### Test -> Demonstrate how the API changes if any. +> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. -### Usage Example +### API and Usage Example -> Provide usage example(s) for easier usage. +> Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python -# Add code snippet or script demonstrating how to use this +# Add code snippet or script demonstrating how to use this ``` -### Test - -> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluatuion results, etc. +### Design & Code Changes -### Additional Info. - -- **Issue Number**: Fixes issue # or discussion # if any. -- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none] -- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none] +> Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting -- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). -- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). -- [ ] Add `[BREAKING]` to the PR title if it breaks any API. -- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). -- [ ] New CI unit test(s) are added to cover the code path. -- [ ] Rely on existing unit tests on CI that covers the code path. +> [!IMPORTANT] +> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. + +- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). +- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` +- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). +- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... +- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 000000000..097250672 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,69 @@ +### Adding a New Workflow + +When adding a new workflow for continuous integration (CI), you have two runner options: a fixed runner or a machine from the vemlp. + +- **Fixed Runner**: To use a fixed runner, specify it in your workflow using the `runs-on` keyword, like `runs-on: [L20x8]`. +- **Vemlp Runner**: Opting for a Vemlp machine allows you to launch tasks elastically. + +Here is a template to assist you. This template is designed for using Vemlp machines. Currently, for each workflow, you need to create a `setup` and a `cleanup` job. When using this template, the main parts you need to modify are the `IMAGE` environment variable and the specific `job steps`. + +```yaml +name: Your Default Workflow + +on: + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - ".github/workflows/template.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +permissions: + contents: read + +env: + IMAGE: "your vemlp image" # e.g. "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1" + DYNAMIC_RUNNER_URL: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" # public veFaas api + +jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + task-id: ${{ steps.create-runner.outputs.task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_URL }}" + image: "${{ env.DEFAULT_IMAGE }}" + + your_job: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'default-runner' }}"] + steps: + xxxx # your jobs + + cleanup: + runs-on: ubuntu-latest + needs: [setup, your_job] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_URL }}" + task-id: "${{ needs.setup.outputs.task-id }}" \ No newline at end of file diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml new file mode 100644 index 000000000..948ce5e3f --- /dev/null +++ b/.github/workflows/check-pr-title.yml @@ -0,0 +1,58 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + +on: + pull_request: + types: [opened, edited, synchronize] + +jobs: + check-title: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Run PR title checker + run: python3 tests/special_sanity/check_pr_title.py + env: + PR_TITLE: ${{ github.event.pull_request.title }} + + - name: Run PR description checker + run: python3 tests/special_sanity/check_pr_description.py + env: + PR_TITLE: ${{ github.event.pull_request.title }} + GITHUB_EVENT_PATH: ${{ github.event_path }} diff --git a/.github/workflows/checkpoint_converter.yml b/.github/workflows/checkpoint_converter.yml index cea6dbf16..906d1231f 100644 --- a/.github/workflows/checkpoint_converter.yml +++ b/.github/workflows/checkpoint_converter.yml @@ -1,3 +1,36 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + + name: checkpoint_converter # latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0 @@ -27,7 +60,7 @@ on: - ".github/workflows/checkpoint_converter.yml" - ".github/workflows/e2e_ppo_trainer_megatron.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_ppo_trainer_megatron.sh" + - "tests/special_e2e/run_ppo_trainer_megatron.sh" - "verl/trainer/main_ppo.py" - "verl/trainer/config/ppo_megatron_trainer.yaml" @@ -51,7 +84,7 @@ jobs: NO_PROXY: "localhost,127.0.0.1" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -63,11 +96,11 @@ jobs: - name: Running Huggingface to Megatron dist_ckpt converter (Qwen/Qwen2.5-0.5B) run: | ray stop --force - python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/Qwen/Qwen2.5-0.5B + python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/Qwen/Qwen2.5-0.5B --test - name: Running Huggingface to Megatron dist_ckpt converter (deepseek-ai/deepseek-coder-1.3b-instruct) run: | ray stop --force - python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct --output_path checkpoints/deepseek-ai/deepseek-coder-1.3b-instruct + python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct --output_path checkpoints/deepseek-ai/deepseek-coder-1.3b-instruct --test - name: Clean up run: | rm -rf checkpoints @@ -81,7 +114,7 @@ jobs: HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable HF_ENDPOINT: "https://hf-mirror.com" container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -98,6 +131,10 @@ jobs: run: | ray stop --force python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat --use_cpu_initialization + - name: Running distributed Huggingface to Megatron dist_ckpt CPU converter (Qwen/Qwen1.5-MoE-A2.7B-Chat) + run: | + ray stop --force + torchrun --nproc_per_node 8 --nnodes 1 scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen1.5-MoE-A2.7B-Chat --output_path checkpoints/Qwen/Qwen1.5-MoE-A2.7B-Chat_dist --use_cpu_initialization - name: clean up run: | rm -rf checkpoints diff --git a/.github/workflows/cpu_unit_tests.yml b/.github/workflows/cpu_unit_tests.yml new file mode 100644 index 000000000..8b873065b --- /dev/null +++ b/.github/workflows/cpu_unit_tests.yml @@ -0,0 +1,83 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + +name: cpu_unit_tests + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - .github/workflows/cpu_unit_tests.yml + - "!recipe/**/*.py" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + cpu_unit_tests: + runs-on: ubuntu-latest + timeout-minutes: 10 # Increase this timeout value as needed + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install the current repository + run: | + pip install -e .[test,prime,geo] + pip install --upgrade "ray>=2.40.0" pillow + - name: Running CPU unit tests + run: | + [ ! -d "$HOME/verl-data" ] && huggingface-cli download verl-team/gsm8k-v0.4.1 --repo-type dataset --local-dir ~/verl-data/gsm8k + python3 examples/data_preprocess/geo3k.py + echo '[pytest]' > pytest.ini + echo 'python_files = *_on_cpu.py' >> pytest.ini + pytest -s -x --asyncio-mode=auto tests/ diff --git a/.github/workflows/dataset.yml b/.github/workflows/dataset.yml deleted file mode 100644 index 5e9fa4136..000000000 --- a/.github/workflows/dataset.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: dataset - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - paths: - - "verl/utils/**/*.py" - - .github/workflows/dataset.yml - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - ray: - runs-on: [L20x8] - timeout-minutes: 10 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip install -e .[test] - pip install --upgrade "ray>=2.40.0" - pip install cupy-cuda12x - - name: Running dataset tests - run: | - [ ! -d "$HOME/verl-data" ] && git clone --depth 1 https://github.com/eric-haibin-lin/verl-data ~/verl-data - python3 examples/data_preprocess/geo3k.py - pytest -s -x tests/utils/gpu_tests/dataset/test_rl_dataset.py - pytest -s -x tests/utils/gpu_tests/dataset/test_sft_dataset.py - # pytest -s -x tests/utils/gpu_tests/dataset/test_rm_dataset.py - - name: Running ray test using cupy (move it to L20 when dockerfile ready) - run: | - cd tests/ray_gpu - pytest -s -x test_rvdz.py diff --git a/.github/workflows/disabled/e2e_prime.yml b/.github/workflows/disabled/e2e_prime.yml index 61c7e86cf..b7d4f4e98 100644 --- a/.github/workflows/disabled/e2e_prime.yml +++ b/.github/workflows/disabled/e2e_prime.yml @@ -25,7 +25,7 @@ on: # Entrypoints - ".github/workflows/e2e_prime.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_prime.sh" + - "tests/special_e2e/run_prime.sh" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -47,7 +47,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -63,4 +63,4 @@ jobs: - name: Running GSM8K E2E with prime alg run: | ray stop --force - bash tests/e2e/run_prime.sh + bash tests/special_e2e/run_prime.sh diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 7faee7e8d..55eaa2eaa 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -1,3 +1,35 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + name: doc_test on: @@ -54,3 +86,15 @@ jobs: echo "🚨 Sphinx doc build contained ERRORs - see _build/sphinx.log" exit 1 fi + if grep -q "WARNING: document isn't included in any toctree" _build/sphinx.log; then + echo "🚨 Sphinx doc build contained WARNING. Please include newly added docs in index.rst. See _build/sphinx.log for details" + exit 1 + fi + if grep -q "WARNING: Inline emphasis" _build/sphinx.log; then + echo "🚨 Sphinx doc build contained WARNING. Please check inline emphasis is correct. See _build/sphinx.log for details" + exit 1 + fi + if grep -q "WARNING: Definition list ends without a blank line" _build/sphinx.log; then + echo "🚨 Sphinx doc build contained WARNING. Please check if the indentation is correct. See _build/sphinx.log for details" + exit 1 + fi diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index a489b0fa7..c66d77235 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -1,3 +1,35 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + name: e2e_ascend on: @@ -12,7 +44,21 @@ on: - main paths: - "**/*.py" - - .github/workflows/e2e_ascend.yml + - "requirements-npu.txt" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # Entrypoints + - ".github/workflows/e2e_ascend.yml" + - "examples/data_preprocess/gsm8k.py" + - "examples/data_preprocess/geo3k.py" + - "tests/special_e2e/ppo_trainer" + - "verl/trainer/main_ppo.py" + - "verl/trainer/config/ppo_trainer.yaml" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -28,7 +74,7 @@ jobs: runs-on: [self-hosted, npu-0] timeout-minutes: 30 # Increase this timeout value as needed container: - image: crispig/verl_npu:cann8.1rc1-py3.10-torch2.5.1-vllm0.7.3-250603 + image: crispig/verl_npu:cann8.1rc1-py3.10-torch2.5.1-vllm-ascend0.7.3.post1-250616 volumes: - /usr/local/dcmi:/usr/local/dcmi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi @@ -44,7 +90,7 @@ jobs: --device /dev/hisi_hdc --network host --privileged - --shm-size 2g + --shm-size 16g env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -63,17 +109,34 @@ jobs: pip3 install hf_transfer peft pip3 install -r requirements-npu.txt pip install -e . + - name: Install torchviison + run: | + pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu - name: Prepare gsm8k dataset run: | ray stop --force python3 examples/data_preprocess/gsm8k.py - - name: Running gsm8k e2e training tests with LoRA on ASCEND NPU + - name: Prepare geo3k dataset run: | ray stop --force - bash tests/e2e/sft/run_sft.sh + python3 examples/data_preprocess/geo3k.py + - name: Running gsm8k e2e training tests with peft sft on ASCEND NPU + run: | + ray stop --force + bash tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh rm -rf $HOME/ckpts - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU run: | ray stop --force - bash tests/npu/run_qwen2_5_05b_grpo.sh - rm -rf $HOME/ckpts \ No newline at end of file + bash tests/special_npu/run_qwen2_5_05b_grpo.sh + rm -rf $HOME/ckpts + - name: Running geo3k e2e training tests with GRPO on ASCEND NPU + run: | + ray stop --force + bash tests/special_npu/run_qwen2_5_vl_3b_npu.sh + rm -rf $HOME/ckpts + - name: Running gsm8k e2e training tests with DAPO on ASCEND NPU + run: | + ray stop --force + bash tests/special_npu/run_qwen2_5_05b_dapo.sh + rm -rf $HOME/ckpts diff --git a/.github/workflows/e2e_dapo.yml b/.github/workflows/e2e_dapo.yml index 784e2a071..446d2c108 100644 --- a/.github/workflows/e2e_dapo.yml +++ b/.github/workflows/e2e_dapo.yml @@ -1,12 +1,57 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + name: e2e_dapo on: # Trigger the workflow on push or pull request, # but only for the main branch + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. push: branches: - main - v0.* + paths: + - "verl/*.py" + # Other entrypoints + - "!examples/*trainer*" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Megatron + - "!verl/workers/**/megatron_*.py" + - "!recipe/**" + - "recipe/dapo" pull_request: branches: - main @@ -27,7 +72,7 @@ on: # Entrypoints - ".github/workflows/e2e_dapo.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_dapo.sh" + - "tests/special_e2e/run_dapo.sh" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -49,7 +94,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -64,4 +109,4 @@ jobs: - name: Running the E2E test with the DAPO algorithm run: | ray stop --force - bash tests/e2e/run_dapo.sh + bash tests/special_e2e/run_dapo.sh diff --git a/.github/workflows/e2e_eval_aime24.yml b/.github/workflows/e2e_eval_aime24.yml index 63532d7ce..9728be7c2 100644 --- a/.github/workflows/e2e_eval_aime24.yml +++ b/.github/workflows/e2e_eval_aime24.yml @@ -1,18 +1,68 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + name: e2e_eval_aime24 on: # Trigger the workflow on push or pull request, # but only for the main branch + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. push: branches: - main - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!*.md" + - "!docker/**" + - "!docs/**" + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + - "!recipe/**" + - "recipe/r1" + - "!recipe/r1/README.md" pull_request: branches: - main paths: - "**/*.py" # Other entrypoints + - "!*.md" + - "!docker/**" + - "!docs/**" - "!examples/**" - "!tests/**" - "!verl/trainer/main_*.py" @@ -24,7 +74,7 @@ on: - "!recipe/**" # Entrypoints - ".github/workflows/e2e_eval_aime24.yml" - - "tests/e2e/run_r1_distill_qwen_aime24_eval.sh" + - "tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh" - "verl/trainer/main_generation.py" - "verl/trainer/config/generation.yaml" @@ -37,9 +87,29 @@ concurrency: permissions: contents: read +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + e2e_eval_aime24: - runs-on: [L20x8] + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] timeout-minutes: 40 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} @@ -47,9 +117,6 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -65,4 +132,16 @@ jobs: - name: Running generation and evaluation in AIME 2024 run: | ray stop --force - bash tests/e2e/run_r1_distill_qwen_aime24_eval.sh + bash tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh + + cleanup: + runs-on: ubuntu-latest + needs: [setup, e2e_eval_aime24] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" \ No newline at end of file diff --git a/.github/workflows/e2e_genrm_remote.yml b/.github/workflows/e2e_genrm_remote.yml new file mode 100644 index 000000000..8f06a6cd0 --- /dev/null +++ b/.github/workflows/e2e_genrm_remote.yml @@ -0,0 +1,105 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + +name: e2e_genrm_remote + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "tests/**" + - "!recipe/**" + - "recipe/genrm_remote" + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" + # Home + - "recipe/genrm_remote" + - "!recipe/genrm_remote/README.md" + # Entrypoints + - ".github/workflows/e2e_genrm_remote.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/special_e2e/run_genrm_remote.sh" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + e2e_genrm_remote: + runs-on: [L20x8] + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + container: + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test,gpu] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running the E2E test with the Generative Reward Model + run: | + ray stop --force + bash tests/special_e2e/run_genrm_remote.sh diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index 057d3d0ca..c67dce301 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -3,10 +3,21 @@ name: e2e_ppo_trainer on: # Trigger the workflow on push or pull request, # but only for the main branch + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. push: branches: - main - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" + pull_request: branches: - main @@ -14,10 +25,14 @@ on: paths: - "**/*.py" # Other entrypoints + - "!**/*.md" + - "!docker/**" - "!examples/**" - "!tests/**" - "!verl/trainer/main_*.py" - "!verl/trainer/fsdp_sft_trainer.py" + # Docs + - "!docs/**" # Recipes - "!recipe/**" # Megatron @@ -26,7 +41,7 @@ on: - ".github/workflows/e2e_ppo_trainer.yml" - "examples/data_preprocess/gsm8k.py" - "examples/data_preprocess/geo3k.py" - - "tests/e2e/ppo_trainer" + - "tests/special_e2e/ppo_trainer" - "verl/trainer/main_ppo.py" - "verl/trainer/config/ppo_trainer.yaml" @@ -51,6 +66,9 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} + - name: Install the current repository + run: | + pip install -e . - name: Set ruff --output-format=github run: | sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml @@ -61,7 +79,7 @@ jobs: e2e_ppo_trainer_vllm: runs-on: [L20x8] - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -69,7 +87,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -82,85 +100,112 @@ jobs: run: | ray stop --force python3 examples/data_preprocess/gsm8k.py + # HF sanity + - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for santiy + run: | + ray stop --force + bash tests/special_e2e/ppo_trainer/run_single_gpu.sh # Function RM - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8) run: | ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp8" bash tests/e2e/ppo_trainer/run_function_reward.sh + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp-size8" bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm after resuming run: | ray stop --force - RESUME_MODE=auto VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp8" bash tests/e2e/ppo_trainer/run_function_reward.sh + RESUME_MODE=auto VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp-size8" bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Test merging FSDP checkpoints (Qwen Actor) run: | - exp_name="qwen2.5-0.5b-function-reward-minimal-fsdp8" - python scripts/model_merger.py test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + exp_name="qwen2.5-0.5b-function-reward-minimal-fsdp-size8" + python -m verl.model_merger test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (DDP_SIZE=2, FSDP_SIZE=4) run: | ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True FSDP_SIZE=4 VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-ddp2-fsdp4" bash tests/e2e/ppo_trainer/run_function_reward.sh + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True FSDP_SIZE=4 VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-ddp-size2-fsdp-size4" bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Test merging DDP+FSDP checkpoints (Qwen Actor) run: | - exp_name="qwen2.5-0.5b-function-reward-minimal-ddp2-fsdp4" - python scripts/model_merger.py test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + exp_name="qwen2.5-0.5b-function-reward-minimal-ddp-size2-fsdp-size4" + python -m verl.model_merger test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP2) + run: | + ray stop --force + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp2-size8" STRATEGY=fsdp2 bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Test merging FSDP2 checkpoints (Qwen Actor) + run: | + exp_name="qwen2.5-0.5b-function-reward-minimal-fsdp2-size8" + python -m verl.model_merger test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E without rmpad using function rm run: | ray stop --force - RM_PAD=False bash tests/e2e/ppo_trainer/run_function_reward.sh + RM_PAD=False bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm (GRPO) run: | ray stop --force - ADV_ESTIMATOR=grpo USE_KL=True bash tests/e2e/ppo_trainer/run_function_reward.sh + ADV_ESTIMATOR=grpo USE_KL=True bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm (ReMax) run: | ray stop --force - ADV_ESTIMATOR=remax USE_KL=True bash tests/e2e/ppo_trainer/run_function_reward.sh + ADV_ESTIMATOR=remax USE_KL=True bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using customized reward function run: | ray stop --force - CUSTOM_REWARD_FN=True bash tests/e2e/ppo_trainer/run_function_reward.sh + CUSTOM_REWARD_FN=True bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with in-reward kl and kl loss run: | ray stop --force - USE_KL=True bash tests/e2e/ppo_trainer/run_function_reward.sh + USE_KL=True bash tests/special_e2e/ppo_trainer/run_function_reward.sh # LoRA tests - name: Running GSM8K E2E training tests on 8 L20 GPUs with grpo lora using function rm with use_shm run: | ray stop --force - ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors bash tests/e2e/ppo_trainer/run_function_reward.sh + ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with grpo lora using function rm with use_shm and layered_summon run: | ray stop --force - ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors LAYERED_SUMMON=True bash tests/e2e/ppo_trainer/run_function_reward.sh + ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors LAYERED_SUMMON=True TOTAL_TRAIN_STEPS=1 SAVE_FREQ=1 FSDP_SIZE=4 VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal" bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Test GRPO LoRA checkpoints merging function + run: | + export EXP_NAME="qwen2.5-0.5b-function-reward-minimal" + ls checkpoints/verl-test/${EXP_NAME}/global_step_1/actor + cat checkpoints/verl-test/${EXP_NAME}/global_step_1/actor/huggingface/config.json + python3 -m verl.model_merger merge --backend fsdp --local_dir checkpoints/verl-test/${EXP_NAME}/global_step_1/actor/ --target_dir checkpoints/verl-test/${EXP_NAME}/global_step_1/actor/huggingface - name: Running GSM8K E2E training tests on 8 L20 GPUs with grpo lora using function rm with use_shm and layered_summon with fsdp2 run: | ray stop --force - ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors LAYERED_SUMMON=True STRATEGY=fsdp2 bash tests/e2e/ppo_trainer/run_function_reward.sh + ADV_ESTIMATOR=grpo USE_SHM=True LORA_RANK=32 LOAD_FORMAT=safetensors LAYERED_SUMMON=True STRATEGY=fsdp2 bash tests/special_e2e/ppo_trainer/run_function_reward.sh # Model RM - name: Running GRPO GSM8K E2E training tests with FSDP on 8 L20 GPUs (DeepSeek) run: | ray stop --force - MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/ppo_trainer/run_function_reward.sh + MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E with rmpad using model rm run: | ray stop --force - bash tests/e2e/ppo_trainer/run_model_reward.sh + bash tests/special_e2e/ppo_trainer/run_model_reward.sh - name: Running GSM8K E2E without rmpad using model rm run: | ray stop --force - RM_PAD=False bash tests/e2e/ppo_trainer/run_model_reward.sh + RM_PAD=False bash tests/special_e2e/ppo_trainer/run_model_reward.sh - name: Running GSM8K E2E with rmpad using model rm and ulysses sp=2 run: | ray stop --force - SP_SIZE=2 bash tests/e2e/ppo_trainer/run_model_reward.sh + SP_SIZE=2 bash tests/special_e2e/ppo_trainer/run_model_reward.sh - name: Running GSM8K E2E with rmpad using model rm and dynamic batch size run: | ray stop --force - SEQ_BALANCE=True bash tests/e2e/ppo_trainer/run_model_reward.sh + SEQ_BALANCE=True bash tests/special_e2e/ppo_trainer/run_model_reward.sh - name: Running GSM8K E2E with rmpad using model rm with Liger Kernel enabled run: | ray stop --force - LIGER=True bash tests/e2e/ppo_trainer/run_model_reward.sh + LIGER=True bash tests/special_e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNELS=True bash tests/special_e2e/ppo_trainer/run_model_reward.sh + - name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled + run: | + ray stop --force + FUSED_KERNEL=True FUSED_KERNEL_BACKEND=triton bash tests/special_e2e/ppo_trainer/run_model_reward.sh e2e_ppo_trainer_vllm_vlm: runs-on: [L20x8] @@ -173,7 +218,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0 + image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=50g # Visual dataloader requires large memory steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -181,13 +226,33 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install -e .[test,geo,vllm] + pip3 install -e .[test,gpu,vllm,geo,trl] + pip install "transformers[hf_xet]<4.53.0" # Fix for transformers 4.53.0 # Geo3k - - name: Prepare Geo3k dataset + - name: Prepare GEO3K dataset run: | ray stop --force python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E training tests on 8 L20 GPUs with rmpad using function rm + - name: Running GEO3K VLM GRPO E2E training tests on 8 L20 GPUs with rmpad using function rm + run: | + ray stop --force + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + SP_SIZE=2 \ + bash tests/special_e2e/ppo_trainer/run_function_reward.sh + + - name: Running GEO3K VLM PPO E2E training tests on 8 L20 GPUs with rmpad using function rm + run: | + ray stop --force + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \ + ADV_ESTIMATOR=gae RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + SP_SIZE=2 \ + bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GEO3K VLM GRPO E2E lora training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ @@ -195,7 +260,9 @@ jobs: MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \ ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ SP_SIZE=2 \ - bash tests/e2e/ppo_trainer/run_function_reward.sh + LORA_RANK=32 LORA_EXCLUDE=".*visual.*" \ + bash tests/special_e2e/ppo_trainer/run_function_reward.sh + e2e_ppo_trainer_sglang: runs-on: [L20x8] @@ -208,7 +275,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -224,7 +291,17 @@ jobs: - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt run: | ray stop --force - ENGINE=sglang bash tests/e2e/ppo_trainer/run_function_reward.sh + ENGINE=sglang bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GSM8K E2E training tests on sglang async + run: | + ray stop --force + TOTAL_TRAIN_STEPS=2 ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GSM8K E2E training tests on vllm async + run: | + ray stop --force + export VLLM_USE_V1=1 + ray start --head + TOTAL_TRAIN_STEPS=2 ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh e2e_ppo_trainer_sglang_multiturn_with_tool: runs-on: [L20x8] @@ -237,7 +314,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -253,16 +330,16 @@ jobs: - name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang run: | ray stop --force - bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh + bash tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh - name: Running GSM8K with tool E2E training tests with FSDP2 run: | ray stop --force - FSDP_STRATEGY=fsdp2 bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh + FSDP_STRATEGY=fsdp2 bash tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh e2e_ppo_trainer_sglang_vlm: runs-on: [L20x8] needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed + timeout-minutes: 60 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} @@ -270,7 +347,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=50g # Visual dataloader requires large memory steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -280,11 +357,11 @@ jobs: run: | pip3 install -e .[test,geo,gpu,sglang] # Geo3k - - name: Prepare Geo3k dataset + - name: Prepare GEO3K dataset run: | ray stop --force python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E training tests on 8 L20 GPUs with rmpad using function rm + - name: Running GEO3K VLM E2E training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ @@ -293,45 +370,30 @@ jobs: ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ - bash tests/e2e/ppo_trainer/run_function_reward.sh - - e2e_ppo_trainer_fused_kernels_vllm: - runs-on: [L20x8] - needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0 - options: --gpus all --shm-size=50g # Visual dataloader requires large memory - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test,geo,vllm] - # Geo3k - - name: Prepare Geo3k dataset + bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GEO3K VLM E2E with rmpad using torch fused kernel (Qwen2.5-VL) run: | ray stop --force - python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ + MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ + ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ + ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ + bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GEO3K VLM E2E with rmpad using triton fused kernel (Qwen2.5-VL) run: | ray stop --force - FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ + FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \ + TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ - GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ + ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ - bash tests/e2e/ppo_trainer/run_function_reward.sh + bash tests/special_e2e/ppo_trainer/run_function_reward.sh - e2e_ppo_trainer_fused_kernels_sglang: + e2e_ppo_trainer_sglang_vlm_multiturn_with_tool: runs-on: [L20x8] needs: pre_commit_for_ppo timeout-minutes: 40 # Increase this timeout value as needed @@ -342,8 +404,8 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=50g # Visual dataloader requires large memory + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 + options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -351,17 +413,15 @@ jobs: - name: Install the current repository run: | pip3 install -e .[test,geo,gpu,sglang] - - name: Prepare Geo3k dataset + - name: Prepare geo3k dataset with tool run: | ray stop --force - python3 examples/data_preprocess/geo3k.py - - name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL) + python3 examples/data_preprocess/geo3k_multiturn_w_tool.py --local_dir $HOME/data/geo3k_verl_sgl_multi_turn_preprocessed + - name: Running GEO3K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang run: | ray stop --force - FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \ - MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \ - MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \ - ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \ - ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \ - ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ - bash tests/e2e/ppo_trainer/run_function_reward.sh \ No newline at end of file + bash tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh + - name: Running GEO3K with tool E2E training tests with FSDP2 + run: | + ray stop --force + FSDP_STRATEGY=fsdp2 bash tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml deleted file mode 100644 index 8fc466ec7..000000000 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ /dev/null @@ -1,267 +0,0 @@ -name: e2e_ppo_trainer_megatron -# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0 - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Recipes - - "!recipe/**" - # FSDP - - "!verl/workers/**/*dp_*.py" - # Entrypoints - - ".github/workflows/e2e_ppo_trainer_megatron.yml" - - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_ppo_trainer_megatron.sh" - - "verl/trainer/main_ppo.py" - - "verl/trainer/config/ppo_megatron_trainer.yaml" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - e2e_ppo_trainer_megatron-deepseek: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) - run: | - ray stop --force - ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) - run: | - ray stop --force - RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) - run: | - exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek) - run: | - ray stop --force - ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-qwen3: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving - run: | - ray stop --force - ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming - run: | - ray stop --force - RESUME_MODE=auto MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) - run: | - exp_name="qwen3-0.6b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) - run: | - ray stop --force - ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with tie-embedding Megatron (Qwen) with train tp > infer tp - run: | - ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=2 INFER_TP=1 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp < infer tp - run: | - ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=1 INFER_TP=2 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-qwen-override-transformer-config: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Prepare dist_ckpt of Qwen2.5-0.5B, uneven layer distribution only supports dist_ckpt - run: | - python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/verl-test/qwen2.5-0.5b-megatron - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) - run: | - ray stop --force - SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 SKIP_SAVE_HF_MODEL=1 bash tests/e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=8 +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=4 actor_rollout_ref.actor.megatron.use_dist_checkpointing=true actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron actor_rollout_ref.ref.megatron.use_dist_checkpointing=true actor_rollout_ref.ref.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron critic.megatron.use_dist_checkpointing=true critic.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron reward_model.megatron.use_dist_checkpointing=true reward_model.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron - cp -r checkpoints checkpoints-dut - SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 bash tests/e2e/run_ppo_trainer_megatron.sh - - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) - run: | - exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-deepseek-override-transformer-config: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) - run: | - ray stop --force - SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=2 COMMON_VPP=null bash tests/e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=true +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=true - - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) - run: | - exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - - name: clean up - run: | - rm -rf checkpoints - e2e_ppo_trainer_megatron-moe-expert-parallel: - runs-on: [L20x8] - timeout-minutes: 60 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install --no-deps -e .[test] - - name: Prepare GSM8K dataset - run: | - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) - run: | - ray stop --force - ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \ - PPO_MAX_TOKEN_LEN=512 FWD_MAX_TOKEN_LEN=512 \ - MAX_PROMPT_LENGTH=256 MAX_RESPONSE_LENGTH=256 \ - MODEL_ID=Qwen/Qwen1.5-MoE-A2.7B-Chat \ - COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=4 COMMON_ETP=1 INFER_TP=8 \ - USE_DIST_CKPT=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/e2e/run_ppo_trainer_megatron.sh - - name: clean up - run: | - rm -rf checkpoints - diff --git a/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml b/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml new file mode 100644 index 000000000..988a78113 --- /dev/null +++ b/.github/workflows/e2e_ppo_trainer_megatron_sglang.yml @@ -0,0 +1,366 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + +name: e2e_ppo_trainer_megatron_sglang + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch. + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. + push: + branches: + - main + - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # FSDP + - "!verl/workers/**/*dp_*.py" + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!docker/**" + # Docs + - "!**/*.md" + - "!docs/**" + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # FSDP + - "!verl/workers/**/*dp_*.py" + # Entrypoints + - ".github/workflows/e2e_ppo_trainer_megatron_sglang.yml" + - "examples/data_preprocess/gsm8k.py" + - "examples/data_preprocess/geo3k.py" + - "tests/special_e2e/run_ppo_trainer_megatron.sh" + - "verl/trainer/main_ppo.py" + - "verl/trainer/config/ppo_megatron_trainer.yaml" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + +jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + + e2e_ppo_trainer_megatron-deepseek: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + ENGINE=sglang ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + export VLLM_USE_V1=1 + ray start --head + ENGINE=sglang MODE=async RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) + run: | + exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek) + run: | + ray stop --force + ENGINE=sglang ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen3: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving + run: | + ray stop --force + ENGINE=sglang ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) testing learning rate scheduler + run: | + ray stop --force + ENGINE=sglang LR_WARMUP_STEPS=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + + - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) + run: | + exp_name="qwen3-0.6b-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with tie-embedding Megatron (Qwen) with train tp > infer tp + run: | + ray stop --force + ENGINE=sglang VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=2 INFER_TP=1 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp < infer tp + run: | + ray stop --force + ENGINE=sglang VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=1 INFER_TP=2 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen-override-transformer-config: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Prepare dist_ckpt of Qwen2.5-0.5B, uneven layer distribution only supports dist_ckpt + run: | + python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/verl-test/qwen2.5-0.5b-megatron + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + ENGINE=sglang SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=8 +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=4 actor_rollout_ref.actor.megatron.use_dist_checkpointing=true actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron actor_rollout_ref.ref.megatron.use_dist_checkpointing=true actor_rollout_ref.ref.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron critic.megatron.use_dist_checkpointing=true critic.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron reward_model.megatron.use_dist_checkpointing=true reward_model.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron + cp -r checkpoints checkpoints-dut + ENGINE=sglang SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) + run: | + exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-deepseek-override-transformer-config: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + ENGINE=sglang SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=2 COMMON_VPP=null bash tests/special_e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=true +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=true + - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) + run: | + exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-moe-expert-parallel: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \ + PPO_MAX_TOKEN_LEN=512 FWD_MAX_TOKEN_LEN=512 \ + MAX_PROMPT_LENGTH=256 MAX_RESPONSE_LENGTH=256 \ + MODEL_ID=Qwen/Qwen1.5-MoE-A2.7B-Chat \ + ENGINE=sglang COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=4 COMMON_ETP=1 INFER_TP=8 \ + USE_DIST_CKPT=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen2_5vl-3b: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare Geo3k dataset + run: | + python3 examples/data_preprocess/geo3k.py + - name: Prepare dist_ckpt of Qwen2.5-VL-3B, only supports dist_ckpt + run: | + python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron + - name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + ENGINE=sglang TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=true DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + + cleanup: + runs-on: ubuntu-latest + needs: [setup, + e2e_ppo_trainer_megatron-deepseek, + e2e_ppo_trainer_megatron-qwen3, + e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding, + e2e_ppo_trainer_megatron-qwen-override-transformer-config, + e2e_ppo_trainer_megatron-deepseek-override-transformer-config, + e2e_ppo_trainer_megatron-moe-expert-parallel, + e2e_ppo_trainer_megatron-qwen2_5vl-3b] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" \ No newline at end of file diff --git a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml new file mode 100644 index 000000000..b89e890cd --- /dev/null +++ b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml @@ -0,0 +1,372 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + +name: e2e_ppo_trainer_megatron_vllm + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch. + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. + push: + branches: + - main + - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # FSDP + - "!verl/workers/**/*dp_*.py" + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!docker/**" + # Docs + - "!**/*.md" + - "!docs/**" + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Recipes + - "!recipe/**" + # FSDP + - "!verl/workers/**/*dp_*.py" + # Entrypoints + - ".github/workflows/e2e_ppo_trainer_megatron_vllm.yml" + - "examples/data_preprocess/gsm8k.py" + - "examples/data_preprocess/geo3k.py" + - "tests/special_e2e/run_ppo_trainer_megatron.sh" + - "verl/trainer/main_ppo.py" + - "verl/trainer/config/ppo_megatron_trainer.yaml" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + +jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + + e2e_ppo_trainer_megatron-deepseek: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + export VLLM_USE_V1=1 + ray start --head + MODE=async USE_FUSED_KERNELS=True RESUME_MODE=auto MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TOTAL_TRAIN_STEPS=2 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) + run: | + exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: Test Megatron distributed checkpoints merging function (DeepSeek) + run: | + exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" + torchrun --nproc_per_node 4 --nnodes 1 -m verl.model_merger merge --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --target_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/hf_model + - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek) + run: | + ray stop --force + ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen3: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving + run: | + ray stop --force + ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) testing learning rate scheduler + run: | + ray stop --force + LR_WARMUP_STEPS=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + + - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) + run: | + exp_name="qwen3-0.6b-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with tie-embedding Megatron (Qwen) with train tp > infer tp + run: | + ray stop --force + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=2 INFER_TP=1 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp < infer tp + run: | + ray stop --force + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=1 INFER_TP=2 MODEL_ID=Qwen/Qwen2.5-1.5B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen-override-transformer-config: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Prepare dist_ckpt of Qwen2.5-0.5B, uneven layer distribution only supports dist_ckpt + run: | + python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/verl-test/qwen2.5-0.5b-megatron + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=8 +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=4 actor_rollout_ref.actor.megatron.use_dist_checkpointing=true actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron actor_rollout_ref.ref.megatron.use_dist_checkpointing=true actor_rollout_ref.ref.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron critic.megatron.use_dist_checkpointing=true critic.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron reward_model.megatron.use_dist_checkpointing=true reward_model.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron + cp -r checkpoints checkpoints-dut + SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) + run: | + exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-deepseek-override-transformer-config: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=2 COMMON_VPP=null bash tests/special_e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=true +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=true + - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) + run: | + exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" + python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-moe-expert-parallel: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + pip3 install mbridge + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek) + run: | + ray stop --force + ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \ + PPO_MAX_TOKEN_LEN=512 FWD_MAX_TOKEN_LEN=512 \ + MAX_PROMPT_LENGTH=256 MAX_RESPONSE_LENGTH=256 \ + MODEL_ID=Qwen/Qwen1.5-MoE-A2.7B-Chat USE_MBRIDGE=True \ + COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=4 COMMON_ETP=1 INFER_TP=8 \ + USE_DIST_CKPT=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + e2e_ppo_trainer_megatron-qwen2_5vl-3b: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + pip3 install "transformers[hf_xet]<4.52.0" + - name: Prepare Geo3k dataset + run: | + python3 examples/data_preprocess/geo3k.py + - name: Prepare dist_ckpt of Qwen2.5-VL-3B, only supports dist_ckpt + run: | + python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron + - name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=true DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints + + cleanup: + runs-on: ubuntu-latest + needs: [setup, + e2e_ppo_trainer_megatron-deepseek, + e2e_ppo_trainer_megatron-qwen3, + e2e_ppo_trainer_megatron-different-train-infer-tp-qwen-tie-embedding, + e2e_ppo_trainer_megatron-qwen-override-transformer-config, + e2e_ppo_trainer_megatron-deepseek-override-transformer-config, + e2e_ppo_trainer_megatron-moe-expert-parallel, + e2e_ppo_trainer_megatron-qwen2_5vl-3b] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" \ No newline at end of file diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml index b657dae4f..6f6fcc574 100644 --- a/.github/workflows/e2e_sft.yml +++ b/.github/workflows/e2e_sft.yml @@ -1,3 +1,34 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + name: e2e_sft on: @@ -25,7 +56,7 @@ on: # Entrypoints - ".github/workflows/e2e_sft.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/sft" + - "tests/special_e2e/sft" - "verl/trainer/fsdp_sft_trainer.py" - "verl/trainer/config/sft_trainer.yaml" @@ -38,9 +69,28 @@ concurrency: permissions: contents: read +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" e2e_sft: - runs-on: [L20x8] + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] timeout-minutes: 20 # Increase this timeout value as needed env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} @@ -48,9 +98,6 @@ jobs: NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: @@ -66,25 +113,37 @@ jobs: - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm run: | ray stop --force - bash tests/e2e/sft/run_sft.sh + bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs w/o rmpad using function rm run: | ray stop --force - RM_PAD=False bash tests/e2e/sft/run_sft.sh + RM_PAD=False bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with sequence parallism run: | ray stop --force - SP_SIZE=2 bash tests/e2e/sft/run_sft.sh + SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Check loss difference between sequence parallel vs. default implementation run: | ray stop --force - ENTRYPOINT="tests/e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/e2e/sft/run_sft.sh + ENTRYPOINT="tests/special_e2e/sft/test_sp_loss_match.py" SP_SIZE=2 bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with sequence parallism and liger run: | ray stop --force - SP_SIZE=2 LIGER=True bash tests/e2e/sft/run_sft.sh + SP_SIZE=2 LIGER=True bash tests/special_e2e/sft/run_sft.sh - name: Running GSM8K E2E training tests with LoRA run: | ray stop --force - LORA_RANK=32 bash tests/e2e/sft/run_sft.sh + LORA_RANK=32 bash tests/special_e2e/sft/run_sft.sh # TODO: multiturn + + cleanup: + runs-on: ubuntu-latest + needs: [setup, e2e_sft] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" diff --git a/.github/workflows/e2e_spin.yml b/.github/workflows/e2e_spin.yml index 0ec51115f..cff18df0e 100644 --- a/.github/workflows/e2e_spin.yml +++ b/.github/workflows/e2e_spin.yml @@ -7,6 +7,24 @@ on: branches: - main - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" + # Home + - "recipe/spin" + # Entrypoints + - ".github/workflows/e2e_spin.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/special_e2e/run_spin.sh" + - "!examples" pull_request: branches: - main @@ -27,13 +45,18 @@ on: # Entrypoints - ".github/workflows/e2e_spin.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_spin.sh" + - "tests/special_e2e/run_spin.sh" - "!examples" # Declare permissions just read content. permissions: contents: read +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + jobs: e2e_spin: runs-on: [L20x8] @@ -45,7 +68,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -63,4 +86,4 @@ jobs: - name: Running the E2E test with the spin algorithm run: | ray stop --force - bash tests/e2e/run_spin.sh + bash tests/special_e2e/run_spin.sh diff --git a/.github/workflows/e2e_sppo.yml b/.github/workflows/e2e_sppo.yml index cc871a0a1..4f8cc1b2e 100644 --- a/.github/workflows/e2e_sppo.yml +++ b/.github/workflows/e2e_sppo.yml @@ -7,13 +7,30 @@ on: branches: - main - v0.* + paths: + - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Megatron + - "!verl/workers/**/megatron_*.py" + # Home + - "recipe/sppo" + # Entrypoints + - ".github/workflows/e2e_sppo.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/special_e2e/run_sppo.sh" pull_request: branches: - main - v0.* paths: - "**/*.py" - # Other entrypoints + # Other entrypoints - "!examples/**" - "!tests/**" - "!verl/trainer/main_*.py" @@ -27,12 +44,17 @@ on: # Entrypoints - ".github/workflows/e2e_sppo.yml" - "examples/data_preprocess/gsm8k.py" - - "tests/e2e/run_sppo.sh" + - "tests/special_e2e/run_sppo.sh" # Declare permissions just read content. permissions: contents: read +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + jobs: e2e_sppo: runs-on: [L20x8] @@ -44,7 +66,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -62,4 +84,4 @@ jobs: - name: Running the E2E test with the SPPO algorithm run: | ray stop --force - bash tests/e2e/run_sppo.sh + bash tests/special_e2e/run_sppo.sh diff --git a/.github/workflows/gpu_unit_tests.yml b/.github/workflows/gpu_unit_tests.yml new file mode 100644 index 000000000..24c0fe0b6 --- /dev/null +++ b/.github/workflows/gpu_unit_tests.yml @@ -0,0 +1,100 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + +name: GPU unit tests + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.4.x + paths: + - "**/*.py" + - .github/workflows/gpu_unit_tests.yml + pull_request: + branches: + - main + - v0.4.x + paths: + # The order that you define paths patterns matters: + # A matching negative pattern (prefixed with !) after a positive match will exclude the path. + # A matching positive pattern after a negative match will include the path again. + - "**/*.py" + # Other entrypoints + - "!examples/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + - "!recipe/**" + # Entrypoints + - .github/workflows/gpu_unit_tests.yml + - "tests/**test_*.py" + # Ignore CPU tests + - "!tests/*_on_cpu.py" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + gpu_unit_tests: + runs-on: [L20x8] + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install --no-deps -e .[test] + pip3 install --upgrade "ray>=2.40.0" + pip3 install cupy-cuda12x + - name: Run all GPU unit tests + run: | + pytest -s -x --ignore-glob="*test_linear_cross_entropy_tp.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" tests/ + - name: Testing LinearCrossEntropyTP Correctness, Computation Time and Memory Consumption + run: | + LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_linear_cross_entropy_tp.py diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml deleted file mode 100644 index 0a6f9163d..000000000 --- a/.github/workflows/kernels.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: kernels -# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0 - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.2.x - paths: - - "**/*.py" - - .github/workflows/kernels.yml - pull_request: - branches: - - main - - v0.2.x - paths: - - "**/*.py" - # Other entrypoints - - "!examples/**" - - "!tests/**" - - "!verl/trainer/main_*.py" - - "!verl/trainer/fsdp_sft_trainer.py" - # Recipes - - "!recipe/**" - # Entrypoints - - .github/workflows/kernels.yml - - "tests/kernels/*" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - e2e_gsm8k_megatron: - runs-on: [L20x8] - timeout-minutes: 40 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" - HF_HUB_ENABLE_HF_TRANSFER: 1 - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install hf_transfer - pip3 install --no-deps -e .[test] - - name: Testing LinearCrossEntropy Correction, Computation Time and Memory Consumption - run: | - python3 tests/kernels/test_linear_cross_entropy.py \ No newline at end of file diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml index c554f1071..88563d8b3 100644 --- a/.github/workflows/model.yml +++ b/.github/workflows/model.yml @@ -1,3 +1,35 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. +# name: Check PR Title + name: model_rmpad on: @@ -15,14 +47,19 @@ on: - "verl/**/*.py" # Entrypoints - ".github/workflows/model.yml" - - "tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py" - - "tests/models/test_transformers_ulysses.py" - - "tests/distributed/run_all.sh" + - "tests/special_distributed/test_fsdp_ckpt.py" + - "tests/models/**" + - "tests/special_distributed/run_all.sh" # Declare permissions just read content. permissions: contents: read +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + jobs: model_rmpad: runs-on: [L20x8] @@ -34,7 +71,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -49,11 +86,10 @@ jobs: pytest -s tests/models/test_transformer.py - name: Running rmpad model tests on 8 L20 GPUs + latest flash_attn run: | - pip3 install --upgrade flash_attn --no-build-isolation pytest -s tests/models/test_transformer.py - name: Running FSDP rmpad model tests on 8 L20 GPUs + latest flash_attn run: | - STRATEGY=fsdp torchrun --nproc_per_node=8 tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py + STRATEGY=fsdp torchrun --nproc_per_node=8 tests/special_distributed/test_fsdp_ckpt.py - name: Running transformers ulysses tests on 8 L20 GPUs + latest transformers run: | torchrun --nproc_per_node=8 -m pytest tests/models/test_transformers_ulysses.py @@ -79,7 +115,7 @@ jobs: torchrun --nproc_per_node=8 -m pytest tests/models/test_transformers_ulysses.py - name: Run distributed test run: | - bash tests/distributed/run_all.sh + bash tests/special_distributed/run_all.sh # TODO: Move this back to model_rmpad once FSDP2 is stable. # NOTE: List as an independent job to make rerun easier. @@ -93,7 +129,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -105,5 +141,4 @@ jobs: pip3 install --upgrade transformers - name: Running FSDP2 rmpad model tests on 8 L20 GPUs + latest flash_attn run: | - pip3 install --upgrade flash_attn --no-build-isolation - STRATEGY=fsdp2 torchrun --nproc_per_node=8 tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py + STRATEGY=fsdp2 torchrun --nproc_per_node=8 tests/special_distributed/test_fsdp_ckpt.py diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3e3b3e527..80cfa0945 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -25,10 +25,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} + - name: Install the current repository + run: | + pip install -e . - name: Set ruff --output-format=github run: | sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml git add .pre-commit-config.yaml + # Check "--all-files" by default - uses: pre-commit/action@v3.0.1 - with: - extra_args: "" # Overriding default "--all-files" diff --git a/.github/workflows/ray_cpu_test.yml b/.github/workflows/ray_cpu_test.yml deleted file mode 100644 index 1c68df0c8..000000000 --- a/.github/workflows/ray_cpu_test.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: ray_cpu - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "verl/single_controller/*.py" - - .github/workflows/ray_cpu_test.yml - - "!recipe/**/*.py" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - ray_cpu: - runs-on: ubuntu-latest - timeout-minutes: 10 # Increase this timeout value as needed - strategy: - matrix: - python-version: ["3.10"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install the current repository - run: | - pip install -e .[test] - pip install --upgrade "ray>=2.40.0" - - name: Running ray tests that can be tested on CPU machines - run: | - cd tests/ray_cpu - pytest -s -x --ignore=test_check_worker_alive.py . diff --git a/.github/workflows/ray_gpu_test.yml b/.github/workflows/ray_gpu_test.yml deleted file mode 100644 index 5143965ea..000000000 --- a/.github/workflows/ray_gpu_test.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: ray_gpu - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "verl/single_controller/*.py" - - .github/workflows/ray_gpu_test.yml - - "!recipe/**/*.py" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - ray_gpu: - runs-on: [L20x8] - timeout-minutes: 10 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip install -e .[test] - pip install --upgrade "ray>=2.40.0" - - name: Running ray tests that need 8 GPUs - run: | - cd tests/ray_gpu - pytest -s -x --ignore=test_rvdz.py . diff --git a/.github/workflows/sandbox.yml b/.github/workflows/sandbox.yml deleted file mode 100644 index 23d7b3ed8..000000000 --- a/.github/workflows/sandbox.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: sandbox - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - .github/workflows/sandbox.yml - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - sandbox: - runs-on: [L20x8] - timeout-minutes: 10 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test,prime] - pip3 install vllm==0.5.4 - - name: Running sandbox tests on 8 L20 GPUs - run: | - cd tests/sandbox - pytest -s -x . diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml index 75b349c90..0478d4d36 100644 --- a/.github/workflows/sanity.yml +++ b/.github/workflows/sanity.yml @@ -1,3 +1,35 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. +# name: Check PR Title + name: sanity on: @@ -14,6 +46,7 @@ on: paths: - "**/*.py" - .github/workflows/sanity.yml + - "tests/special_sanity/**" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -42,13 +75,25 @@ jobs: pip install -e .[test] - name: Run sanity test run: | - pytest -s -x tests/sanity + pytest -s -x tests/special_sanity - name: Run license test run: | - python3 tests/sanity/check_license.py --directory . + python3 tests/special_sanity/check_license.py --directory . - name: Assert naming convention run: | if grep -rIn --exclude-dir=.git --exclude-dir=.github --exclude-dir=venv --exclude-dir=__pycache__ 'veRL' .; then echo "Please use verl instead of veRL in the codebase" exit 1 fi + - name: Validate test folder structure + run: python3 tests/special_sanity/validate_structure.py + - name: Assert documentation requirement for functions + run: python3 tests/special_sanity/validate_imported_docs.py + - name: Assert device api usage in verl/recipe + run: python3 tests/special_sanity/check_device_api_usage.py --directory ./recipe + - name: Assert device api usage in verl/verl + run: python3 tests/special_sanity/check_device_api_usage.py --directory ./verl + - name: Assert documentation time info + run: python3 tests/special_sanity/check_docs_time_info.py + - name: Check docstrings for specified files + run: python3 tests/special_sanity/check_docstrings.py diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index 59ae19cc9..5999e7c43 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -1,3 +1,34 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + name: sgl on: @@ -34,6 +65,7 @@ on: - ".github/workflows/sgl.yml" - "tests/rollout/*sglang*" - "tests/rollout/async_rollout_utils.py" + - "tests/workers/rollout/*interaction*" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -56,7 +88,7 @@ jobs: HF_HUB_ENABLE_HF_TRANSFER: 1 SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True" container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -64,16 +96,26 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install hf_transfer + pip3 install hf_transfer fastmcp pip3 install -e .[test,gpu,sglang] --no-deps - name: Download Model to Use run: | huggingface-cli download 'Qwen/Qwen2-7B-Instruct' + huggingface-cli download 'Qwen/Qwen2.5-0.5B' + huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct export HF_HUB_OFFLINE=1 - name: Test the latest SGLang run: | cd tests/workers/rollout torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_spmd.py + - name: Test the latest SGLang Rollout async with interaction + run: | + cd tests/workers/rollout + torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_rollout_w_interaction.py + - name: Test the latest SGLang Multi Interaction + run: | + cd tests/workers/rollout + torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_multi_interaction.py - name: Test the latest SGLang Rollout async with tool run: | cd tests/workers/rollout @@ -85,4 +127,16 @@ jobs: - name: Test the latest SGLang Rollout async with search tool run: | cd tests/workers/rollout - pytest -s test_sglang_async_rollout_search_tools.py \ No newline at end of file + pytest -s test_sglang_async_rollout_search_tools.py + - name: Test the latest SGLang Rollout async with mcp search tool + run: | + cd tests/workers/rollout + pytest -s test_sglang_async_rollout_mcp_tools.py + - name: Test the latest SGLang Rollout async with agent loop + run: | + ROLLOUT_NAME=sglang pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py + # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests + - name: Test the latest SGLang Rollout async with multimodal delta + run: | + cd tests/workers/rollout + pytest -s test_sglang_async_rollout_multimodal_delta.py \ No newline at end of file diff --git a/.github/workflows/type-coverage-check.yml b/.github/workflows/type-coverage-check.yml new file mode 100644 index 000000000..3010736b1 --- /dev/null +++ b/.github/workflows/type-coverage-check.yml @@ -0,0 +1,29 @@ +name: Type Annotation and Docstring Coverage + +on: + pull_request: + paths: + - '**/*.py' + +jobs: + type-coverage-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # 🚨 Important: fetch full history so `origin/main` is available + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install gitpython + pip install -e .[sglang] + - name: Run type annotation coverage check + run: | + python3 tests/special_sanity/type_coverage_check.py + - name: Run docstring coverage check + run: | + python3 tests/special_sanity/check_api_docs.py verl diff --git a/.github/workflows/utils_cpu_test.yml b/.github/workflows/utils_cpu_test.yml deleted file mode 100644 index e3ec220d0..000000000 --- a/.github/workflows/utils_cpu_test.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: utils_cpu_test - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - .github/workflows/utils_cpu_test.yml - - "!recipe/**/*.py" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - utils_cpu_test: - runs-on: ubuntu-latest - timeout-minutes: 10 # Increase this timeout value as needed - strategy: - matrix: - python-version: ["3.10"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install the current repository - run: | - pip install -e .[test] - - name: Running test protocol.py - run: | - cd tests - pytest -s -x test_protocol.py - - name: running utils cpu tests - run: | - cd tests/utils/cpu_tests - pytest -s -x . - - name: Running trainer tests - run: | - cd tests/trainer - pytest -s -x . diff --git a/.github/workflows/utils_gpu_test.yml b/.github/workflows/utils_gpu_test.yml deleted file mode 100644 index fa6879729..000000000 --- a/.github/workflows/utils_gpu_test.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: utils_gpu_test - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - - v0.* - pull_request: - branches: - - main - - v0.* - paths: - - "**/*.py" - - .github/workflows/utils_gpu_test.yml - - "!recipe/**/*.py" - -# Cancel jobs on the same ref if a new one is triggered -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -# Declare permissions just read content. -permissions: - contents: read - -jobs: - utils_gpu_test: - runs-on: [L20x8] - timeout-minutes: 20 # Increase this timeout value as needed - container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install the current repository - run: | - pip install -e .[test] - - name: Running utils gpu tests - run: | - cd tests/utils/gpu_tests - pytest -s -x --ignore=dataset/ --ignore=checkpoint/ . \ No newline at end of file diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index e6cf582e0..0b8ddc6ec 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -1,3 +1,34 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + name: vllm on: @@ -28,8 +59,8 @@ on: - "!**/*sglang*" # Entrypoints - ".github/workflows/vllm.yml" - - "tests/e2e/generation" - - "tests/rollout" + - "tests/special_e2e/generation" + - "tests/workers/rollout" - "verl/trainer/main_generation.py" - "verl/trainer/config/generation.yaml" @@ -61,37 +92,35 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | - pip3 install -e .[test] - pip3 install vllm==0.5.4 + pip3 install -e .[test,vllm] + pip install tensordict==0.6.2 - name: Download Model to Use run: | huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct + huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct huggingface-cli download 'Qwen/Qwen2-7B-Instruct' huggingface-cli download 'deepseek-ai/deepseek-llm-7b-chat' export HF_HUB_OFFLINE=1 # Disable requests to avoid network errors - - name: Running vllm tests on 8 L20 GPUs - run: | - cd tests/workers/rollout - torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py - name: Test the latest vLLM run: | - pip3 install --upgrade vllm==0.7.3 - cd tests/workers/rollout - torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py + torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s tests/workers/rollout/rollout_vllm/test_vllm_spmd.py + - name: Test the latest vLLM on model with rope scaling + run: | + torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py - name: Run Qwen 0.5B generation test run: | - cd tests/e2e/generation + cd tests/special_e2e/generation export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet" MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=4 GEN_TP=2 bash ./run_gen_qwen05.sh rm -rf "${OUTPUT_PATH}" - name: Run Qwen 0.5B generation test when world_size == 1 run: | - cd tests/e2e/generation + cd tests/special_e2e/generation export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet" MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh rm -rf "${OUTPUT_PATH}" - - name: Running multi-turn rollout tests on 8 L20 GPUs + - name: Test the latest vLLM Rollout async with agent loop run: | - pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2 - python3 tests/workers/rollout/test_vllm_multi_turn.py + ROLLOUT_NAME=vllm pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py + # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests diff --git a/.gitignore b/.gitignore index dd3df16f0..2ba4c0b3a 100644 --- a/.gitignore +++ b/.gitignore @@ -139,6 +139,7 @@ outputs !/LICENSE !/pyproject.toml !/setup.py +!/performance_tuning_guide.md # Slurm logs slurm/ @@ -158,3 +159,9 @@ artifacts/ # Analysis analysis/ + +data_preprocess/data_inspection.ipynb +wandb/ +outputs/ +slurm/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72e099f2e..3cf768614 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,16 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.11.4" + rev: "v0.12.2" hooks: - id: ruff args: ["--fix", "--show-fixes", "--output-format=full"] exclude: ^.*\.(ipynb)$ - id: ruff-format + + - repo: local + hooks: + - id: autogen-trainer-cfg + name: Generate and verify verl/trainer/config/_generated_*.yaml + entry: scripts/generate_trainer_config.sh + language: script + pass_filenames: false \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..e953f113e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,89 @@ +# Contributing to verl + +Thank you for considering a contribution to verl! We welcome contributions of any kind - bug fixes, enhancements, documentation improvements, or even just feedback. Whether you're an experienced developer or this is your first open-source project, your help is invaluable. + +Your support can take many forms: +- Report issues or unexpected behaviors. +- Suggest or implement new features. +- Improve or expand documentation. +- Review pull requests and assist other contributors. +- Spread the word: share verl in blog posts, social media, or give the repo a ⭐. + +## Finding Issues to Contribute + +Looking for ways to dive in? Check out these issues: +- [Good first issues](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) +- [Call for contribution](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22call%20for%20contribution%22) +Furthermore, you can learn the development plan and roadmap via [RFC](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3ARFC) and [Roadmap](https://github.com/volcengine/verl/issues?q=state%3Aopen%20label%3A%22roadmap%22). + + +## Developing + +- **Python-only**: install verl via `pip install -e .[test,vllm]` or `pip install -e .[test,sglang]` and iterate quickly. For full dependency setup, check out the verl [installation doc](https://verl.readthedocs.io/en/latest/start/install.html). + +## Code Linting and Formatting + +We rely on pre-commit to keep our code consistent. To set it up: + +```bash +pip install pre-commit +pre-commit install +# for staged changes +pre-commit run +# for all files in the repo +pre-commit run --all-files +# run a specific hook with pre-commit +# pre-commit run --all-files --show-diff-on-failure --color=always +pre-commit run --all-files --show-diff-on-failure --color=always ruff +pre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg +``` + +## Testing + +Our test suites run on GitHub Actions. Check these workflows for details: +- [GPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/gpu_unit_tests.yml) +- [CPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/cpu_unit_tests.yml) +- [vLLM tests](https://github.com/volcengine/verl/blob/main/.github/workflows/vllm.yml) +- [SGLang tests](https://github.com/volcengine/verl/blob/main/.github/workflows/sgl.yml) + +### Adding CI tests + +If possible, please add CI test(s) for your new feature: + +1. Find the most relevant workflow yml file, which usually corresponds to a `hydra` default config (e.g. `ppo_trainer`, `ppo_megatron_trainer`, `sft_trainer`, etc). +2. Add related path patterns to the `paths` section if not already included. +3. Minimize the workload of the test script(s) (see existing scripts for examples). + +## Building the Docs +``` +# Ensure verl is on your PYTHONPATH, e.g.: +pip install -e .[test] + +# Install documentation dependencies +pip install -r requirements-docs.txt + +# Generate HTML docs +make clean +make html + +# Preview locally +python -m http.server -d _build/html/ +``` +Open your browser at http://localhost:8000 to explore the docs. + +## Pull Requests & Code Reviews + +Thanks for submitting a PR! To streamline reviews: +- Follow our Pull Request Template for title format and checklist. +- Adhere to our pre-commit lint rules and ensure all checks pass. +- Update docs for any user-facing changes. +- Add or update tests in the CI workflows, or explain why tests aren't applicable. + +## License + +See the [LICENSE](https://github.com/volcengine/verl/blob/main/LICENSE) file for full details. + +## Thank You + +We appreciate your contributions to verl. Your efforts help make the project stronger and more user-friendly. Happy coding! + diff --git a/README.md b/README.md index 3498e3031..a4c50b69f 100644 --- a/README.md +++ b/README.md @@ -233,4 +233,4 @@ If you find the repo helpful, please cite: doi = {10.48550/arXiv.2506.14965}, url = {https://arxiv.org/abs/2506.14965} } -``` \ No newline at end of file +``` diff --git a/data_preprocess/sample_testset.py b/data_preprocess/sample_testset.py new file mode 100644 index 000000000..cfe32b32e --- /dev/null +++ b/data_preprocess/sample_testset.py @@ -0,0 +1,30 @@ +import pandas as pd + +# File paths +reasoning_gym_file = "/mnt/sharefs/users/haonan.li/data/k2/test_12k_len/logic__reasoning_gym_4.3k.parquet" +synlogic_file = "/mnt/sharefs/users/haonan.li/data/k2/test_12k_len/logic__synlogic_1.4k.parquet" + +# Load datasets +df_reasoning_gym = pd.read_parquet(reasoning_gym_file) +df_synlogic = pd.read_parquet(synlogic_file) + +# Sample reasoning_gym: 5 rows per ability +sampled_reasoning_gym = df_reasoning_gym.groupby('ability').apply( + lambda x: x.sample(min(5, len(x)), random_state=42) +).reset_index(drop=True) + +# Sample synlogic: 10 rows per data_source +sampled_synlogic = df_synlogic.groupby('data_source').apply( + lambda x: x.sample(min(10, len(x)), random_state=42) +).reset_index(drop=True) + +# Create output filenames with exact numbers +reasoning_gym_output = f"/mnt/sharefs/users/haonan.li/data/k2/test_12k_len/logic__reasoning_gym_{len(sampled_reasoning_gym)}.parquet" +synlogic_output = f"/mnt/sharefs/users/haonan.li/data/k2/test_12k_len/logic__synlogic_{len(sampled_synlogic)}.parquet" + +# Save to separate files +sampled_reasoning_gym.to_parquet(reasoning_gym_output, index=False) +sampled_synlogic.to_parquet(synlogic_output, index=False) + +print(f"Sampled {len(sampled_reasoning_gym)} from reasoning_gym -> {reasoning_gym_output}") +print(f"Sampled {len(sampled_synlogic)} from synlogic -> {synlogic_output}") diff --git a/data_preprocess/sft_rlhf/full_hh_rlhf.py b/data_preprocess/sft_rlhf/full_hh_rlhf.py index 10a0aa9d7..4625f2822 100644 --- a/data_preprocess/sft_rlhf/full_hh_rlhf.py +++ b/data_preprocess/sft_rlhf/full_hh_rlhf.py @@ -62,7 +62,7 @@ def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm") local_dir = os.path.expanduser(local_dir) os.makedirs(local_dir, exist_ok=True) - for dataset, name in zip([train_dataset, test_dataset], ["train", "test"]): + for dataset, name in zip([train_dataset, test_dataset], ["train", "test"], strict=True): output = {"prompt": [], "chosen": [], "rejected": []} for data in tqdm(dataset): # add chosen diff --git a/data_preprocess/sft_rlhf/geo3k.py b/data_preprocess/sft_rlhf/geo3k.py index eb6a388fe..2df225dfc 100644 --- a/data_preprocess/sft_rlhf/geo3k.py +++ b/data_preprocess/sft_rlhf/geo3k.py @@ -38,7 +38,8 @@ instruction_following = ( r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. " - r"The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}." + r"The reasoning process MUST BE enclosed within tags. " + r"The final answer MUST BE put in \boxed{}." ) # add a row to each data item that represents a unique id diff --git a/data_preprocess/sft_rlhf/geo3k_multiturn_w_tool.py b/data_preprocess/sft_rlhf/geo3k_multiturn_w_tool.py new file mode 100644 index 000000000..6e006910f --- /dev/null +++ b/data_preprocess/sft_rlhf/geo3k_multiturn_w_tool.py @@ -0,0 +1,99 @@ +# Copyright 2023-2025 SGLang Team +# Copyright Amazon.com, Inc. or its affiliates. +# Copyright 2025 Reallm Labs Ltd. or its affiliates +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocess the Geometry3k dataset to parquet format +""" + +import argparse +import os + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/geo3k_multiturn_w_tool") + parser.add_argument("--hdfs_dir", default=None) + args = parser.parse_args() + data_source = "hiyouga/geometry3k" + dataset = datasets.load_dataset(data_source) + train_dataset = dataset["train"] + test_dataset = dataset["test"] + instruction_following = ( + r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. " + r"The reasoning process MUST BE enclosed within tags. " + r"The final answer MUST BE put in \boxed{}." + ) + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + problem = example.pop("problem") + prompt = problem + " " + instruction_following + answer = example.pop("answer") + images = example.pop("images") + data = { + "data_source": data_source, + "prompt": [ + { + "role": "system", + "content": ( + "You are a math expert. You are given a question and you need to solve it step by step. " + "Reasoning step by step before any tool call. " + "You should use the `calc_geo3k_reward` tool after step by step solving the question, " + "before generate final answer at least once and refine your answer if necessary. " + ), + }, + { + "role": "user", + "content": prompt, + }, + ], + "images": images, + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": answer}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer, + "question": problem, + "need_tools_kwargs": True, + "tools_kwargs": { + "calc_geo3k_reward": { + "create_kwargs": {"ground_truth": answer}, + # "execute_kwargs": {}, + # "calc_reward_kwargs": {}, + # "release_kwargs": {}, + }, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8) + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) diff --git a/data_preprocess/sft_rlhf/gsm8k_multiturn_w_interaction.py b/data_preprocess/sft_rlhf/gsm8k_multiturn_w_interaction.py new file mode 100644 index 000000000..718a87460 --- /dev/null +++ b/data_preprocess/sft_rlhf/gsm8k_multiturn_w_interaction.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the GSM8k dataset to parquet format +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/gsm8k") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + data_source = "openai/gsm8k" + dataset = datasets.load_dataset(data_source, "main") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = "Let's think step by step and output the final answer after `####`." + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": data_source, + "prompt": [ + { + "role": "system", + "content": ( + "You are a math expert. You are given a question and you need to solve it step by step. " + "You should rethinking carefully if user point out your answer is wrong. " + "Put your final answer in the format of `#### `." + ), + }, + { + "role": "user", + "content": question, + }, + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + "interaction_kwargs": { + "name": "gsm8k", + "query": question, + "ground_truth": solution, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) diff --git a/data_preprocess/sft_rlhf/gsm8k_multiturn_w_tool.py b/data_preprocess/sft_rlhf/gsm8k_multiturn_w_tool.py index 6328eed04..400d88566 100644 --- a/data_preprocess/sft_rlhf/gsm8k_multiturn_w_tool.py +++ b/data_preprocess/sft_rlhf/gsm8k_multiturn_w_tool.py @@ -92,6 +92,10 @@ def process_fn(example, idx): # "release_kwargs": {}, }, }, + "interaction_kwargs": { + "query": question, + "ground_truth": solution, + }, }, } return data diff --git a/data_preprocess/sft_rlhf/gsm8k_tool_agent_loop.py b/data_preprocess/sft_rlhf/gsm8k_tool_agent_loop.py new file mode 100644 index 000000000..1271518b4 --- /dev/null +++ b/data_preprocess/sft_rlhf/gsm8k_tool_agent_loop.py @@ -0,0 +1,117 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the GSM8k dataset to parquet format +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/gsm8k") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + data_source = "openai/gsm8k" + dataset = datasets.load_dataset(data_source, "main") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = "Let's think step by step and output the final answer after `####`." + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": data_source, + "agent_name": "tool_agent", + "prompt": [ + { + "role": "system", + "content": ( + "You are a math expert. You are given a question and you need to solve it step by step. " + "Reasoning step by step before any tool call. " + "You should use the `calc_gsm8k_reward` tool after step by step solving the question, " + "before generate final answer at least once and refine your answer if necessary. " + "Put your final answer in the format of `#### `." + ), + }, + { + "role": "user", + "content": question, + }, + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + "need_tools_kwargs": True, + "tools_kwargs": { + "calc_gsm8k_reward": { + "create_kwargs": {"ground_truth": solution}, + # "execute_kwargs": {}, + # "calc_reward_kwargs": {}, + # "release_kwargs": {}, + }, + }, + "interaction_kwargs": { + "query": question, + "ground_truth": solution, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) diff --git a/data_preprocess/sft_rlhf/limo.py b/data_preprocess/sft_rlhf/limo.py index 67088bb1e..656803947 100644 --- a/data_preprocess/sft_rlhf/limo.py +++ b/data_preprocess/sft_rlhf/limo.py @@ -27,25 +27,26 @@ def extract_solution(solution_str): solution = re.search(r"\\boxed\{(.*?)\}", solution_str) assert solution is not None - final_solution = solution.group(1) + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") return final_solution if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='~/data/limo') - parser.add_argument('--hdfs_dir', default=None) + parser.add_argument("--local_dir", default="~/data/gsm8k") + parser.add_argument("--hdfs_dir", default=None) args = parser.parse_args() - data_source = 'GAIR/LIMO' + data_source = "openai/gsm8k" - dataset = datasets.load_dataset(data_source) + dataset = datasets.load_dataset(data_source, "main") - train_dataset = dataset['train'] - test_dataset = dataset['train'] + train_dataset = dataset["train"] + test_dataset = dataset["test"] - instruction_following = "Let's think step by step and output the final answer in \\boxed\{\}." + instruction_following = 'Let\'s think step by step and output the final answer after "####".' # add a row to each data item that represents a unique id def make_map_fn(split): @@ -54,8 +55,8 @@ def process_fn(example, idx): question = question_raw + " " + instruction_following - solution = example.pop('solution') - exact_answer = example.pop('answer') + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) data = { "data_source": data_source, "prompt": [ @@ -65,16 +66,13 @@ def process_fn(example, idx): } ], "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": exact_answer - }, + "reward_model": {"style": "rule", "ground_truth": solution}, "extra_info": { - 'split': split, - 'index': idx, - 'answer': solution, - "question": question, # TODO: use prompt \\boxed later - } + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, } return data @@ -86,8 +84,8 @@ def process_fn(example, idx): local_dir = args.local_dir hdfs_dir = args.hdfs_dir - train_dataset.to_parquet(os.path.join(local_dir, 'train_w_prompt.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test_w_prompt.parquet')) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) if hdfs_dir is not None: makedirs(hdfs_dir) diff --git a/data_preprocess/sft_rlhf/multiturn.py b/data_preprocess/sft_rlhf/multiturn.py index f312cfbf4..4bf0192b8 100644 --- a/data_preprocess/sft_rlhf/multiturn.py +++ b/data_preprocess/sft_rlhf/multiturn.py @@ -51,12 +51,14 @@ def main(): {"role": "user", "content": "Can you explain quantum computing?"}, { "role": "assistant", - "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data.", + "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, " + "such as superposition and entanglement, to perform operations on data.", }, {"role": "user", "content": "How is it different from classical computing?"}, { "role": "assistant", - "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition.", + "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses " + "quantum bits or qubits that can exist in multiple states simultaneously due to superposition.", }, ] } @@ -70,12 +72,19 @@ def main(): {"role": "user", "content": "Write a simple Python function to calculate factorial."}, { "role": "assistant", - "content": "```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number.", + "content": ( + "```python\ndef factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n " + "return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the " + "factorial of a number." + ), }, {"role": "user", "content": "Can you make it iterative instead?"}, { "role": "assistant", - "content": "```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n result *= i\n return result\n```\n\nThis is an iterative version of the factorial function.", + "content": ( + "```python\ndef factorial(n):\n result = 1\n for i in range(1, n+1):\n " + "result *= i\n return result\n```\n\nThis is an iterative version of the factorial function." + ), }, ] } diff --git a/data_preprocess/sft_rlhf/preprocess_search_r1_dataset.py b/data_preprocess/sft_rlhf/preprocess_search_r1_dataset.py index a602d0203..a0c10d59b 100644 --- a/data_preprocess/sft_rlhf/preprocess_search_r1_dataset.py +++ b/data_preprocess/sft_rlhf/preprocess_search_r1_dataset.py @@ -71,7 +71,11 @@ def process_single_row(row, current_split_name, row_index): data_source_tagged = "searchR1_" + str(row.get("data_source", "")) # Build tools kwargs structure - tools_kwargs = {"search": {"create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged}}} + tools_kwargs = { + "search": { + "create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged} + } + } # Build complete extra_info structure extra_info = { @@ -155,8 +159,14 @@ def apply_process_row(row, split_name=split): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download Search-R1 from HuggingFace, process, and save to Parquet.") - parser.add_argument("--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID.") - parser.add_argument("--local_dir", default="~/data/searchR1_processed_direct", help="Local directory to save the processed Parquet files.") + parser.add_argument( + "--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID." + ) + parser.add_argument( + "--local_dir", + default="~/data/searchR1_processed_direct", + help="Local directory to save the processed Parquet files.", + ) parser.add_argument("--hdfs_dir", default=None, help="Optional HDFS directory to copy the Parquet files to.") args = parser.parse_args() diff --git a/data_preprocess/step0_clean_data_column.py b/data_preprocess/step0_clean_data_column.py new file mode 100644 index 000000000..9f31907bb --- /dev/null +++ b/data_preprocess/step0_clean_data_column.py @@ -0,0 +1,37 @@ +folder1 = "/mnt/sharefs/users/haonan.li/data/k2/train_dedup_am_12k_len" +folder2 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len" + +from datasets import load_dataset +from pathlib import Path +import os +import pandas as pd + +def post_process(dataset): + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + # remove column "model_pass_rate" "detailed_scores" "scores" + df = df.drop(columns=["model_pass_rate", "detailed_scores", "scores"]) + # rename column "pass_rate" to "r1_0528_pass_rate" + df = df.rename(columns={"pass_rate": "deepseek_r1_0528_pass_rate"}) + + # Convert back to Dataset + from datasets import Dataset + return Dataset.from_pandas(df) + +if not os.path.exists(folder2): + os.makedirs(folder2) + +for filename in sorted(Path(folder1).glob("*.parquet")): + # Option 1: Save as parquet file (no train split structure) + ds = load_dataset("parquet", data_files=str(filename))['train'] + ds = post_process(ds) + ds.to_parquet(f"{folder2}/{filename.name}") + + # Option 2: Save maintaining original structure with train split + # Uncomment the lines below if you want to maintain the Dataset structure + # full_dataset = load_dataset("parquet", data_files=str(filename)) + # processed_train = post_process(full_dataset['train']) + # from datasets import DatasetDict + # new_dataset = DatasetDict({'train': processed_train}) + # new_dataset.save_to_disk(f"{folder2}/{filename.stem}") diff --git a/data_preprocess/step1_dedup.py b/data_preprocess/step1_dedup.py new file mode 100644 index 000000000..7199bc1cf --- /dev/null +++ b/data_preprocess/step1_dedup.py @@ -0,0 +1,72 @@ +from datasets import load_dataset, Dataset +from pathlib import Path +import os +import pandas as pd + +folder1 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_12k_len_dedup_label_2/" +folder2 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_12k_len_dedup_eval_rl_am_ot_3" + +# Create output directory if it doesn't exist +if not os.path.exists(folder2): + os.makedirs(folder2) + +# Define the duplicate columns to check +duplicate_columns = [ + "is_duplicate_within_rl", + "is_duplicate_within_am", + "is_duplicate_within_openthoughts", + "is_duplicate_within_eval" +] + +def post_process(dataset): + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + # Check which duplicate columns exist in this dataset + available_duplicate_cols = [col for col in duplicate_columns if col in df.columns] + + if available_duplicate_cols: + # Create a mask for rows where ANY of the duplicate columns is True + duplicate_mask = df[available_duplicate_cols].any(axis=1) + + # Filter out rows where any duplicate flag is True + df = df[~duplicate_mask] + + # Reset index to avoid duplicate index column issues + df = df.reset_index(drop=True) + + # Convert back to Dataset + return Dataset.from_pandas(df) + +# Create a list to store row counts +row_counts = [] + +for filename in sorted(Path(folder1).glob("*.parquet")): + target_file = f"{folder2}/{filename.name}" + + # Skip if file already exists in target folder + if os.path.exists(target_file): + print(f"Skipping {filename.name} - already exists in target folder") + continue + + # Load and process dataset + ds = load_dataset("parquet", data_files=str(filename))['train'] + original_rows = len(ds) + ds = post_process(ds) + processed_rows = len(ds) + ds.to_parquet(target_file) + + # Record the row count + row_counts.append(f"{filename.name}: {processed_rows} rows (original: {original_rows})") + print(f"Processed {filename.name}: {processed_rows} rows (filtered from {original_rows} rows based on duplicate flags)") + +# Write row counts to a text file +with open(f"{folder2}/row_counts_after_processing.txt", "w") as f: + f.write("Row counts after processing (filtering duplicate flags):\n") + f.write("=" * 60 + "\n") + for count in row_counts: + f.write(count + "\n") + f.write(f"\nTotal files processed: {len(row_counts)}\n") + +print(f"\nRow counts saved to: {folder2}/row_counts_after_processing.txt") + \ No newline at end of file diff --git a/data_preprocess/step2_rm_flipscore.py b/data_preprocess/step2_rm_flipscore.py new file mode 100644 index 000000000..7aee67568 --- /dev/null +++ b/data_preprocess/step2_rm_flipscore.py @@ -0,0 +1,63 @@ +folder1 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_12k_len_dedup_eval_rl_am_ot_3//" +folder2 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_4/" + +from datasets import load_dataset +from pathlib import Path +import os +import pandas as pd + +def post_process(dataset): + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + # Remove rows where both "qwen3_30b_pass_rate" and "qwen2.5_7b_pass_rate" exist + # and qwen2.5_7b_pass_rate > qwen3_30b_pass_rate + if "qwen3_30b_pass_rate" in df.columns and "qwen2.5_7b_pass_rate" in df.columns: + # Create mask for rows to keep (opposite of condition to remove) + mask = ~((df["qwen3_30b_pass_rate"].notna()) & + (df["qwen2.5_7b_pass_rate"].notna()) & + (df["qwen2.5_7b_pass_rate"] > df["qwen3_30b_pass_rate"]) & + (df["qwen2.5_7b_pass_rate"] > df["deepseek_r1_0528_pass_rate"])) + df = df[mask] + + # Reset index to avoid duplicate index columns issue + df = df.reset_index(drop=True) + + # Convert back to Dataset + from datasets import Dataset + return Dataset.from_pandas(df) + +if not os.path.exists(folder2): + os.makedirs(folder2) + +# Create a list to store row counts +row_counts = [] + +for filename in sorted(Path(folder1).glob("*.parquet")): + target_file = f"{folder2}/{filename.name}" + + # Skip if file already exists in target folder + if os.path.exists(target_file): + print(f"Skipping {filename.name} - already exists in target folder") + continue + + # Option 1: Save as parquet file (no train split structure) + ds = load_dataset("parquet", data_files=str(filename))['train'] + original_rows = len(ds) + ds = post_process(ds) + processed_rows = len(ds) + ds.to_parquet(target_file) + + # Record the row count + row_counts.append(f"{filename.name}: {processed_rows} rows (original: {original_rows})") + print(f"Processed {filename.name}: {processed_rows} rows (removed {original_rows - processed_rows} rows)") + +# Write row counts to a text file +with open(f"{folder2}/row_counts_after_processing.txt", "w") as f: + f.write("Row counts after processing (removing flipscore rows):\n") + f.write("=" * 60 + "\n") + for count in row_counts: + f.write(count + "\n") + f.write(f"\nTotal files processed: {len(row_counts)}\n") + +print(f"\nRow counts saved to: {folder2}/row_counts_after_processing.txt") diff --git a/data_preprocess/step3_scorefilter.py b/data_preprocess/step3_scorefilter.py new file mode 100644 index 000000000..36537e084 --- /dev/null +++ b/data_preprocess/step3_scorefilter.py @@ -0,0 +1,163 @@ +folder1 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_4/" +folder2 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1/" +folder3 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2/" +folder4 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_3/" + +# 5_1: remove rows where deepseek_r1_0528_pass_rate is 0 and 1 (diff_range=middle) +# 5_2: remove rows where deepseek_r1_0528_pass_rate is 0, keep at most 50% of final dataset size as rows with deepseek_r1_0528_pass_rate is 1 (diff_range=easy) +# 5_3: keep at most 10% of final dataset size as rows with deepseek_r1_0528_pass_rate is 1 (diff_range=wide) + + +from datasets import load_dataset +from pathlib import Path +import os +import pandas as pd + +if not os.path.exists(folder2): + os.makedirs(folder2) +if not os.path.exists(folder3): + os.makedirs(folder3) +if not os.path.exists(folder4): + os.makedirs(folder4) + +def post_process_method_5_1(dataset): + """ + Method 5_1: remove rows where deepseek_r1_0528_pass_rate is 0 and 1 (diff=middle) + This keeps only rows where deepseek_r1_0528_pass_rate is not 0 and not 1 + """ + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + # Filter based on deepseek_r1_0528_pass_rate + if 'deepseek_r1_0528_pass_rate' in df.columns: + # Keep rows where deepseek_r1_0528_pass_rate is neither 0 nor 1 + filtered_df = df[(df['deepseek_r1_0528_pass_rate'] != 0) & (df['deepseek_r1_0528_pass_rate'] != 1)] + df = filtered_df + df = df.sample(frac=1, random_state=42).reset_index(drop=True) + # Convert back to Dataset + from datasets import Dataset + return Dataset.from_pandas(df) + + +def post_process_method_5_2(dataset): + """ + Method 5_2: remove rows where deepseek_r1_0528_pass_rate is 0, + keep at most 50% of the final dataset size as rows with deepseek_r1_0528_pass_rate is 1 (diff=easy) + """ + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + if 'deepseek_r1_0528_pass_rate' in df.columns: + # Remove rows where deepseek_r1_0528_pass_rate is 0 + df_no_zeros = df[df['deepseek_r1_0528_pass_rate'] != 0] + + # Separate rows with pass_rate = 1 and others + rows_with_1 = df_no_zeros[df_no_zeros['deepseek_r1_0528_pass_rate'] == 1] + other_rows = df_no_zeros[df_no_zeros['deepseek_r1_0528_pass_rate'] != 1] + + # Calculate how many rows with pass_rate = 1 we can keep + # We want at most 50% of the final dataset to be rows with pass_rate = 1 + # So: rows_with_1_to_keep + len(other_rows) = final_size + # And: rows_with_1_to_keep <= 0.5 * final_size + # Therefore: rows_with_1_to_keep <= 0.5 * (rows_with_1_to_keep + len(other_rows)) + # Solving: rows_with_1_to_keep <= len(other_rows) + max_rows_with_1 = len(other_rows) # This ensures 50% of final dataset + + if len(rows_with_1) > 0: + sample_size = min(len(rows_with_1), max_rows_with_1) + rows_with_1_sampled = rows_with_1.sample(n=sample_size, random_state=42) + else: + rows_with_1_sampled = rows_with_1 + + # Combine the filtered data + df = pd.concat([other_rows, rows_with_1_sampled], ignore_index=True) + df = df.sample(frac=1, random_state=42).reset_index(drop=True) + + # Convert back to Dataset + from datasets import Dataset + return Dataset.from_pandas(df) + + +def post_process_method_5_3(dataset): + """ + Method 5_3: keep at most 10% of the final dataset size as rows with deepseek_r1_0528_pass_rate is 1 (diff=hard) + """ + # Convert Dataset to pandas DataFrame for easier manipulation + df = dataset.to_pandas() + + if 'deepseek_r1_0528_pass_rate' in df.columns: + # Separate rows with pass_rate = 1 and others + rows_with_1 = df[df['deepseek_r1_0528_pass_rate'] == 1] + other_rows = df[df['deepseek_r1_0528_pass_rate'] != 1] + + # Calculate how many rows with pass_rate = 1 we can keep + # We want at most 10% of the final dataset to be rows with pass_rate = 1 + # So: rows_with_1_to_keep + len(other_rows) = final_size + # And: rows_with_1_to_keep <= 0.1 * final_size + # Therefore: rows_with_1_to_keep <= 0.1 * (rows_with_1_to_keep + len(other_rows)) + # Solving: rows_with_1_to_keep <= len(other_rows) / 9 + max_rows_with_1 = max(1, len(other_rows) // 9) # This ensures 10% of final dataset + + if len(rows_with_1) > 0: + sample_size = min(len(rows_with_1), max_rows_with_1) + rows_with_1_sampled = rows_with_1.sample(n=sample_size, random_state=42) + else: + rows_with_1_sampled = rows_with_1 + + # Combine the filtered data + df = pd.concat([other_rows, rows_with_1_sampled], ignore_index=True) + df = df.sample(frac=1, random_state=42).reset_index(drop=True) + + # Convert back to Dataset + from datasets import Dataset + return Dataset.from_pandas(df) + + +def process_folder(source_folder, target_folder, post_process_method, method_name): + """Process all parquet files in source_folder using the specified post_process_method""" + row_counts = [] + + for filename in sorted(Path(source_folder).glob("*.parquet")): + # # to process a single file + # if filename.name != "ifbench__fixed_85.6k.parquet": + # continue + # target_file = f"{target_folder}/{filename.name}" + + # Skip if file already exists in target folder + if os.path.exists(target_file): + print(f"Skipping {filename.name} - already exists in target folder") + continue + + # Load and process the dataset + ds = load_dataset("parquet", data_files=str(filename))['train'] + original_rows = len(ds) + ds = post_process_method(ds) + processed_rows = len(ds) + ds.to_parquet(target_file) + + # Record the row count + row_counts.append(f"{filename.name}: {processed_rows} rows (original: {original_rows})") + print(f"Processed {filename.name}: {processed_rows} rows (filtered from {original_rows} rows using {method_name})") + + # Write row counts to a text file + with open(f"{target_folder}/row_counts_after_processing.txt", "w") as f: + f.write(f"Row counts after processing using {method_name}:\n") + f.write("=" * 60 + "\n") + for count in row_counts: + f.write(count + "\n") + f.write(f"\nTotal files processed: {len(row_counts)}\n") + + print(f"\nRow counts saved to: {target_folder}/row_counts_after_processing.txt") + return row_counts + +# Process each folder with its corresponding method +print("Processing folder2 with method 5_1 (remove pass_rate 0 and 1)...") +process_folder(folder1, folder2, post_process_method_5_1, "method_5_1") + +print("\nProcessing folder3 with method 5_2 (remove pass_rate 0, keep 50% of pass_rate 1)...") +process_folder(folder1, folder3, post_process_method_5_2, "method_5_2") + +print("\nProcessing folder4 with method 5_3 (keep 10% of pass_rate 1)...") +process_folder(folder1, folder4, post_process_method_5_3, "method_5_3") + +print("\nAll processing completed!") diff --git a/data_preprocess/step4_datamix.py b/data_preprocess/step4_datamix.py new file mode 100644 index 000000000..ea8732699 --- /dev/null +++ b/data_preprocess/step4_datamix.py @@ -0,0 +1,116 @@ +folder1 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1" +folder2 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2" +folder3 = "/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_3" + +from datasets import load_dataset +from pathlib import Path +import os +import pandas as pd +import shutil + +def create_datamix_folders(): + """Create new folders with _datamix_6 suffix for each source folder""" + folders = [folder1, folder2, folder3] + target_folders = [] + + for folder in folders: + target_folder = folder + "_datamix_6" + if not os.path.exists(target_folder): + os.makedirs(target_folder) + print(f"Created folder: {target_folder}") + else: + print(f"Folder already exists: {target_folder}") + target_folders.append(target_folder) + + return target_folders + + +def process_all_files(source_folder, target_folder): + """Process specific files with restrictions and copy remaining files""" + # Define math files and files to process with specific percentages + math_files = [ + "math__combined_118.2k.part1.parquet", + "math__combined_118.2k.part2.parquet" + ] + + files_to_process = { + "stem__web_31.7k.parquet": 0.3, + "stem__nemotron_13.3k.parquet": 0.3, + "simulation__codeio_fixed_12.1k.parquet": 0.1, + "logic__reasoning_gym_40.6k.parquet": 0.3, + "logic__synlogic_12.1k.parquet": 0.3, + "ifbench__fixed_85.6k.parquet": 0.1 + } + + # Get all parquet files in source folder + all_files = [f for f in os.listdir(source_folder) if f.endswith('.parquet')] + remaining_files = [f for f in all_files if f not in math_files and f not in files_to_process.keys()] + + + # Copy math files + print(f" Found {len(remaining_files)} remaining files to copy") + for filename in math_files: + source_path = os.path.join(source_folder, filename) + target_path = os.path.join(target_folder, filename) + + if os.path.exists(source_path): + shutil.copy2(source_path, target_path) + print(f" Copied {filename}") + else: + print(f" Warning: {filename} not found in {source_folder}") + + # Check math total rows + math_total_rows = 0 + for math_file in math_files: + source_path = os.path.join(source_folder, math_file) + + if os.path.exists(source_path): + ds = load_dataset("parquet", data_files=source_path)['train'] + rows = len(ds) + math_total_rows += rows + print(f" Total math rows: {math_total_rows}") + + # Process specific files with restrictions + print(" Processing specific files with restrictions...") + for filename, percentage in files_to_process.items(): + source_path = os.path.join(source_folder, filename) + target_path = os.path.join(target_folder, filename) + + if os.path.exists(source_path): + ds = load_dataset("parquet", data_files=source_path)['train'] + original_rows = len(ds) + # Calculate max rows based on percentage of math total rows + max_rows = int(math_total_rows * percentage) + print(f" {filename}: max rows allowed ({percentage*100}% of math): {max_rows}") + + if original_rows > max_rows: + # Sample the required number of rows + ds = ds.select(range(max_rows)) + print(f" {filename}: sampled {max_rows} rows from {original_rows}") + else: + print(f" {filename}: kept all {original_rows} rows (within limit)") + + ds.to_parquet(target_path) + else: + print(f" Warning: {filename} not found in {source_folder}") + + + +def main(): + """Main function to process all folders""" + print("Creating datamix folders...") + target_folders = create_datamix_folders() + + source_folders = [folder1, folder2, folder3] + + for i, (source_folder, target_folder) in enumerate(zip(source_folders, target_folders), 1): + print(f"\nProcessing folder {i}: {source_folder}") + print(f"Target folder: {target_folder}") + + print("Processing all files...") + process_all_files(source_folder, target_folder) + + print(f"Completed processing folder {i}") + +if __name__ == "__main__": + main() diff --git a/data_preprocess/step5_stats.py b/data_preprocess/step5_stats.py new file mode 100644 index 000000000..d227a96ff --- /dev/null +++ b/data_preprocess/step5_stats.py @@ -0,0 +1,112 @@ +import os +from pathlib import Path +from typing import List, Dict + +import pandas as pd +from pyarrow import parquet as pq + +folder = "/mnt/sharefs/users/haonan.li/data/k2" + + +def check_write_permission(folder_path: Path) -> bool: + """Check if we have write permission to the folder.""" + try: + # Try to create a temporary file to test write permission + test_file = folder_path / ".write_test_temp" + test_file.touch() + test_file.unlink() # Clean up the test file + return True + except (PermissionError, OSError): + return False + + +def collect_row_counts_for_folder(folder_path: Path) -> List[Dict[str, int]]: + """Collect row counts for each .parquet file directly under a folder.""" + results: List[Dict[str, int]] = [] + for file_path in sorted(folder_path.glob("*.parquet")): + parquet_file = pq.ParquetFile(str(file_path)) + num_rows = parquet_file.metadata.num_rows + results.append({"filename": file_path.name, "rows": num_rows}) + return results + + +def write_table_report(folder_path: Path, rows_info: List[Dict[str, int]], total_rows: int) -> Path: + """Write a formatted table report in the folder and return its path.""" + if not rows_info: + return folder_path / "data_summary.txt" + + # Create DataFrame with percentage column + df = pd.DataFrame(rows_info).sort_values(by="filename") + folder_total = df['rows'].sum() + df['percentage'] = (df['rows'] / folder_total * 100).round(2) # Percentage within this folder + + # Calculate column widths dynamically + max_filename_len = max(len(row['filename']) for row in rows_info) + filename_width = max(max_filename_len + 2, 50) # At least 50 chars, add padding + rows_width = 12 + percentage_width = 10 + + total_width = filename_width + rows_width + percentage_width + 6 # 6 for spacing + + # Write as formatted text table + output_path = folder_path / "data_summary.txt" + with open(output_path, 'w') as f: + f.write("=" * total_width + "\n") + f.write(f"DATA SUMMARY FOR: {folder_path.name}\n") + f.write("=" * total_width + "\n") + f.write(f"{'Filename':<{filename_width}} {'Rows':<{rows_width}} {'Percentage':<{percentage_width}}\n") + f.write("-" * total_width + "\n") + + for _, row in df.iterrows(): + f.write(f"{row['filename']:<{filename_width}} {row['rows']:>{rows_width-2},} {row['percentage']:>{percentage_width-1}.2f}%\n") + + f.write("-" * total_width + "\n") + f.write(f"{'TOTAL':<{filename_width}} {df['rows'].sum():>{rows_width-2},} {100.00:>{percentage_width-1}.2f}%\n") + f.write("=" * total_width + "\n") + + return output_path + + +def main() -> None: + # For each immediate subfolder under base_folder, create a table report with row counts + base_folder = Path(folder) + if not base_folder.exists(): + raise FileNotFoundError(f"Base folder not found: {base_folder}") + + # First pass: collect all data to calculate total + all_data = {} + total_rows = 0 + processed_folders = 0 + + for entry in sorted(base_folder.iterdir()): + if not entry.is_dir(): + continue + + # Check write permission before processing + if not check_write_permission(entry): + print(f"Skipping {entry} - no write permission") + continue + + print(f"Processing {entry}") + rows_info = collect_row_counts_for_folder(entry) + + # Store data for later processing + all_data[entry] = rows_info + folder_total = sum(row_info["rows"] for row_info in rows_info) + total_rows += folder_total + processed_folders += 1 + print(f" -> {folder_total:,} rows in {len(rows_info)} files") + + # Second pass: write reports with percentage calculations + print(f"\nWriting reports with percentages...") + for entry, rows_info in all_data.items(): + table_path = write_table_report(entry, rows_info, total_rows) + print(f" -> Report written to {table_path}") + + print(f"\n=== SUMMARY ===") + print(f"Processed {processed_folders} folders") + print(f"Total rows across all folders: {total_rows:,}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/docker/Dockerfile.awsefa b/docker/Dockerfile.extention.awsefa similarity index 87% rename from docker/Dockerfile.awsefa rename to docker/Dockerfile.extention.awsefa index 313999f7f..10be07697 100644 --- a/docker/Dockerfile.awsefa +++ b/docker/Dockerfile.extention.awsefa @@ -1,4 +1,6 @@ -FROM whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 +# Base Image support aws EFA +# Build Image with frameworks based on this +FROM verlai/verl:app-verl0.5-sglang0.4.6.post5-mcore0.12.1 # For aws instances with EFA net interface (Sagemaker AI Pod) # install EFA driver: @@ -48,6 +50,6 @@ ENV OMPI_MCA_pml=^cm,ucx \ NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent \ FI_EFA_USE_HUGE_PAGE=0 -# docker build -t whatcanyousee/verl:awsefa --label "commit=$(git rev-parse --short HEAD)" . +# docker build -t verl:awsefa --label "commit=$(git rev-parse --short HEAD)" . # on aws: -# docker run --ipc=host --privileged --name verldev --gpus all --network=host --shm-size=1800gb -itd whatcanyousee/verl:awsefa +# docker run --ipc=host --privileged --name verldev --gpus all --network=host --shm-size=1800gb -itd verl:awsefa diff --git a/docker/Dockerfile.ngc.vllm0.8 b/docker/Dockerfile.ngc.vllm0.8 index 6a297f0e7..127839fe7 100644 --- a/docker/Dockerfile.ngc.vllm0.8 +++ b/docker/Dockerfile.ngc.vllm0.8 @@ -72,4 +72,4 @@ RUN pip install --no-cache-dir verl[vllm] -U # Reset pip config RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url \ No newline at end of file + pip config unset global.extra-index-url diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index abd0b6b61..e8c209cd0 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,37 +1,295 @@ -# Build the docker in the repo dir: -# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 . -# docker images # you can find your built docker +# FROM "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247" +FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" +SHELL ["/bin/bash", "-ceuxo", "pipefail"] -# Support - Traing: fsdp; Inference: vllm -# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 -# Support - Traing: fsdp; Inference: vllm, sglang -FROM lmsysorg/sglang:v0.4.6.post5-rocm630 +ENV MAX_JOBS=512 -# Set working directory -# WORKDIR $PWD/app +ENV PATH="/usr/local/python3.12/bin:$PATH" +RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ + ln -sf /usr/bin/pip3.12 /usr/bin/pip +############################################ +############################################ +RUN apt-get update +RUN apt-get install -y pkg-config liblzma-dev +############################################ +############################################ + + +########################################### +##########Install TransformerEngine######## +########################################### +WORKDIR /workspace/ +# transformer-engine install +# https://github.com/ROCm/TransformerEngine + +RUN rm -rf TransformerEngine +RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git +WORKDIR /workspace/TransformerEngine +RUN git checkout 236178e5 +# git checkout bb061ade +# git checkout 864405c + +ENV NVTE_FRAMEWORK=pytorch +ENV NVTE_ROCM_ARCH=gfx942 +ENV NVTE_USE_HIPBLASLT=1 +ENV NVTE_USE_ROCM=1 + +# export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" +ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + + +# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS) + +RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + + +#################################################################################### +################Install vllm - sglang require vllm 0.6.7 dependency################# +#################################################################################### +#### Require vllm 0.6.7 - checkout 113274a0 +WORKDIR /workspace/ +RUN rm -rf vllm +RUN pip uninstall -y vllm +# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html +RUN git clone https://github.com/ROCm/vllm.git +# git clone https://github.com/vllm-project/vllm.git +WORKDIR /workspace/vllm +RUN git checkout 113274a0 +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" +#ENV MAX_JOBS=512 +ENV MAX_JOBS=${MAX_JOBS} +RUN pip install "boto3>=1.26.0" +RUN pip install setuptools_scm +# will add src into py. You can delete the repo +RUN python3 setup.py install +WORKDIR /workspace/ +#################################################################################### +#################################################################################### +#################################################################################### + + + +########################################### +############For hack docker################ +########################################### +RUN pip install setuptools==75.8.0 +########################################### +########################################### +########################################### + + + +########################################### +############build sgalng################### +########################################### # Set environment variables +ENV BASE_DIR=/sgl-workspace +ENV BUILD_TYPE=all +ENV SGL_REPO=https://github.com/sgl-project/sglang +ENV SGL_BRANCH=v0.4.6.post5 +ENV TRITON_REPO=https://github.com/ROCm/triton.git +ENV TRITON_COMMIT=improve_fa_decode_3.0.0 +ENV AITER_REPO=https://github.com/ROCm/aiter.git +ENV AITER_COMMIT=v0.1.2 +# v0.1.2 version - commit id: 9d11f47 +# ENV AITER_COMMIT=9d11f47 + +ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 +ENV SGLANG_SET_CPU_AFFINITY=1 +ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 +ENV NCCL_MIN_NCHANNELS=112 +ENV MOE_PADDING=1 +ENV VLLM_FP8_PADDING=1 +ENV VLLM_FP8_ACT_PADDING=1 +ENV VLLM_FP8_WEIGHT_PADDING=1 +ENV VLLM_FP8_REDUCE_CONV=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 +ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" +ENV AMDGPU_TARGETS=gfx942 +ENV ROCM_ARCH=gfx942 ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" +# Switch to working directory +WORKDIR /sgl-workspace + +# Clean and create directory +RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace + +# Clone and build sglang +RUN git clone ${SGL_REPO} \ + && cd sglang \ + && git checkout ${SGL_BRANCH} || echo "Using default branch" \ + && cd sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && python setup_rocm.py install \ + && cd .. \ + && if [ "$BUILD_TYPE" = "srt" ]; then \ + python -m pip --no-cache-dir install -e "python[srt_hip]"; \ + else \ + python -m pip --no-cache-dir install -e "python[all_hip]"; \ + fi \ + && cd /sgl-workspace \ + && cp -r /sgl-workspace/sglang /sglang \ + && python -m pip cache purge + +# Install common Python packages +RUN pip install IPython orjson python-multipart torchao pybind11 + +# Rebuild Triton +RUN pip uninstall -y triton || true \ + && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_COMMIT} \ + && cd python \ + && python3 setup.py install \ + && cd /sgl-workspace + + +# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" +# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" + +# Build aiter +#version: Commit 9d11f47 + # && git checkout ${AITER_COMMIT} \ +RUN pip uninstall -y aiter || true +RUN git clone ${AITER_REPO} \ + && cd aiter \ + && git checkout ${AITER_COMMIT} \ + && git submodule sync \ + && git submodule update --init --recursive \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ + && cd /sgl-workspace + # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ + # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ + +# Copy MI300X config +RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ + /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' | \ + xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} + +# Environment setup complete. +RUN echo "Environment setup complete." + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + + + +########################################### +###############vllm v0.8.5################# +########################################### +# ENV GITHUB_USERNAME=yushengsu-thu +# ENV GITHUB_MAIL=yushengsu@gmail.com + +# RUN git config --global user.name "${GITHUB_USERNAME}" \ +# && git config --global user.email "${GITHUB_MAIL}" + +WORKDIR /workspace/ + +ENV VLLM_TARGET_DEVICE=rocm +ENV ROCM_PATH=/opt/rocm +ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev + +# Find the repo path in: DockerFile/Dockerfile.rocm_yang +# RUN git clone https://github.com/RLFoundation/vllm-patch.git +RUN pip uninstall -y vllm || true +RUN rm -rf vllm-patch +RUN git clone https://github.com/RLFoundation/vllm-patch.git \ + && cd vllm-patch \ + && git checkout v0.8.5-sleep-numa \ + && rm -rf build/ dist/ *.egg-info \ + && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ + && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install + # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + +######################################### +#### Install megatron-core############### +######################################### +RUN pip uninstall -y megatron-core && \ + git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ + cd Megatron-LM-amd_version && \ + pip install -vvv -e . && \ + cd /workspace/ +######################################### +######################################### +######################################### + + + + +####################################### +################apex################### +####################################### +WORKDIR /workspace/ +RUN pip uninstall -y apex && \ + git clone https://github.com/ROCm/apex.git && \ + cd apex && \ + python setup.py install && \ + cd /workspace/ +####################################### +####################################### +####################################### + + + + +################################################################################ +###########################Add torch_memory_saver############################### +################################################################################ +# Set environment variables ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" ENV CFLAGS="-D__HIP_PLATFORM_AMD__" ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" +RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" +################################################################################ +################################################################################ +################################################################################ + + + +######################################## +######Install ray####################### +######################################## +# need to add this patch: https://github.com/ray-project/ray/pull/53531/files +RUN pip uninstall ray -y +RUN pip install "ray[data,train,tune,serve]>=2.47.0" +######################################## +######################################## +######################################## -# Install vllm -RUN pip uninstall -y vllm && \ - rm -rf vllm && \ - git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \ - cd vllm && \ - MAX_JOBS=$(nproc) python3 setup.py install && \ - cd .. && \ - rm -rf vllm -# Copy the entire project directory -COPY . . -# Install dependencies -RUN pip install "tensordict<0.6" --no-deps && \ +########################################## +#######Install other dependencies######### +########################################## +RUN pip install "tensordict==0.6.2" --no-deps && \ pip install accelerate \ codetiming \ datasets \ @@ -43,13 +301,21 @@ RUN pip install "tensordict<0.6" --no-deps && \ peft \ "pyarrow>=15.0.0" \ pylatexenc \ - "ray[data,train,tune,serve]<2.45.0" \ torchdata \ - transformers \ wandb \ orjson \ - pybind11 && \ - pip install -e . --no-deps + pybind11 + +WORKDIR /workspace/ +RUN git clone https://github.com/volcengine/verl.git && \ + cd verl && \ + pip install -e . +########################################## +########################################## +########################################## + + + +WORKDIR /workspace/ -# Install torch_memory_saver -RUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps \ No newline at end of file +CMD ["/usr/bin/bash"] diff --git a/docker/Dockerfile.rocm_verl-0.3.0.post1 b/docker/Dockerfile.rocm_verl-0.3.0.post1 new file mode 100644 index 000000000..185096d9d --- /dev/null +++ b/docker/Dockerfile.rocm_verl-0.3.0.post1 @@ -0,0 +1,58 @@ +# Build the docker in the repo dir: +# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 . +# docker images # you can find your built docker + + +# Support - Traing: fsdp; Inference: vllm +# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 +# Support - Traing: fsdp; Inference: vllm, sglang +FROM lmsysorg/sglang:v0.4.6.post5-rocm630 + +# Set working directory +# WORKDIR $PWD/app + +# Set environment variables +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" + +ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" +ENV CFLAGS="-D__HIP_PLATFORM_AMD__" +ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" + +# Install vllm +RUN pip uninstall -y vllm && \ + rm -rf vllm && \ + git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \ + cd vllm && \ + MAX_JOBS=$(nproc) python3 setup.py install && \ + cd .. && \ + rm -rf vllm + +# Copy the entire project directory +COPY . . + +# Install dependencies +RUN pip install "tensordict==0.6.2" --no-deps && \ + pip install accelerate \ + codetiming \ + datasets \ + dill \ + hydra-core \ + liger-kernel \ + numpy \ + pandas \ + peft \ + "pyarrow>=15.0.0" \ + pylatexenc \ + "ray[data,train,tune,serve]<2.45.0" \ + torchdata \ + transformers \ + wandb \ + orjson \ + pybind11 + +RUN git clone https://github.com/volcengine/verl.git && \ + cd verl && \ + pip install -e . + +# Install torch_memory_saver +RUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps diff --git a/docker/Dockerfile.rocm_verl-0.4.1 b/docker/Dockerfile.rocm_verl-0.4.1 new file mode 100644 index 000000000..b6d30521b --- /dev/null +++ b/docker/Dockerfile.rocm_verl-0.4.1 @@ -0,0 +1,322 @@ +# FROM "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247" +FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" + +SHELL ["/bin/bash", "-ceuxo", "pipefail"] + +ENV MAX_JOBS=512 + +ENV PATH="/usr/local/python3.12/bin:$PATH" +RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ + ln -sf /usr/bin/pip3.12 /usr/bin/pip + +############################################ +############################################ +RUN apt-get update +RUN apt-get install -y pkg-config liblzma-dev +############################################ +############################################ + + +########################################### +##########Install TransformerEngine######## +########################################### +WORKDIR /workspace/ +# transformer-engine install +# https://github.com/ROCm/TransformerEngine + +RUN rm -rf TransformerEngine +RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git +WORKDIR /workspace/TransformerEngine +RUN git checkout 236178e5 +# git checkout bb061ade +# git checkout 864405c + +ENV NVTE_FRAMEWORK=pytorch +ENV NVTE_ROCM_ARCH=gfx942 +ENV NVTE_USE_HIPBLASLT=1 +ENV NVTE_USE_ROCM=1 + +# export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" +ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + + +# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS) + +RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + + +#################################################################################### +################Install vllm - sglang require vllm 0.6.7 dependency################# +#################################################################################### +#### Require vllm 0.6.7 - checkout 113274a0 +WORKDIR /workspace/ +RUN rm -rf vllm +RUN pip uninstall -y vllm +# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html +RUN git clone https://github.com/ROCm/vllm.git +# git clone https://github.com/vllm-project/vllm.git +WORKDIR /workspace/vllm +RUN git checkout 113274a0 +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" +#ENV MAX_JOBS=512 +ENV MAX_JOBS=${MAX_JOBS} +RUN pip install "boto3>=1.26.0" +RUN pip install setuptools_scm +# will add src into py. You can delete the repo +RUN python3 setup.py install +WORKDIR /workspace/ +#################################################################################### +#################################################################################### +#################################################################################### + + + +########################################### +############For hack docker################ +########################################### +RUN pip install setuptools==75.8.0 +########################################### +########################################### +########################################### + + + +########################################### +############build sgalng################### +########################################### +# Set environment variables +ENV BASE_DIR=/sgl-workspace +ENV BUILD_TYPE=all +ENV SGL_REPO=https://github.com/sgl-project/sglang +ENV SGL_BRANCH=v0.4.6.post5 +ENV TRITON_REPO=https://github.com/ROCm/triton.git +ENV TRITON_COMMIT=improve_fa_decode_3.0.0 +ENV AITER_REPO=https://github.com/ROCm/aiter.git +ENV AITER_COMMIT=v0.1.2 +# v0.1.2 version - commit id: 9d11f47 +# ENV AITER_COMMIT=9d11f47 + +ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 +ENV SGLANG_SET_CPU_AFFINITY=1 +ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 +ENV NCCL_MIN_NCHANNELS=112 +ENV MOE_PADDING=1 +ENV VLLM_FP8_PADDING=1 +ENV VLLM_FP8_ACT_PADDING=1 +ENV VLLM_FP8_WEIGHT_PADDING=1 +ENV VLLM_FP8_REDUCE_CONV=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 +ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" +ENV AMDGPU_TARGETS=gfx942 +ENV ROCM_ARCH=gfx942 +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" + +# Switch to working directory +WORKDIR /sgl-workspace + +# Clean and create directory +RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace + +# Clone and build sglang +RUN git clone ${SGL_REPO} \ + && cd sglang \ + && git checkout ${SGL_BRANCH} || echo "Using default branch" \ + && cd sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && python setup_rocm.py install \ + && cd .. \ + && if [ "$BUILD_TYPE" = "srt" ]; then \ + python -m pip --no-cache-dir install -e "python[srt_hip]"; \ + else \ + python -m pip --no-cache-dir install -e "python[all_hip]"; \ + fi \ + && cd /sgl-workspace \ + && cp -r /sgl-workspace/sglang /sglang \ + && python -m pip cache purge + +# Install common Python packages +RUN pip install IPython orjson python-multipart torchao pybind11 + +# Rebuild Triton +RUN pip uninstall -y triton || true \ + && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_COMMIT} \ + && cd python \ + && python3 setup.py install \ + && cd /sgl-workspace + + +# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" +# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" + +# Build aiter +#version: Commit 9d11f47 + # && git checkout ${AITER_COMMIT} \ +RUN pip uninstall -y aiter || true +RUN git clone ${AITER_REPO} \ + && cd aiter \ + && git checkout ${AITER_COMMIT} \ + && git submodule sync \ + && git submodule update --init --recursive \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ + && cd /sgl-workspace + # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ + # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ + +# Copy MI300X config +RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ + /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' | \ + xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} + +# Environment setup complete. +RUN echo "Environment setup complete." + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + + + +########################################### +###############vllm v0.8.5################# +########################################### +# ENV GITHUB_USERNAME=yushengsu-thu +# ENV GITHUB_MAIL=yushengsu@gmail.com + +# RUN git config --global user.name "${GITHUB_USERNAME}" \ +# && git config --global user.email "${GITHUB_MAIL}" + +WORKDIR /workspace/ + +ENV VLLM_TARGET_DEVICE=rocm +ENV ROCM_PATH=/opt/rocm +ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev + +# Find the repo path in: DockerFile/Dockerfile.rocm_yang +# RUN git clone https://github.com/RLFoundation/vllm-patch.git +RUN pip uninstall -y vllm || true +RUN rm -rf vllm-patch +RUN git clone https://github.com/RLFoundation/vllm-patch.git \ + && cd vllm-patch \ + && git checkout v0.8.5-sleep-numa \ + && rm -rf build/ dist/ *.egg-info \ + && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ + && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install + # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop + +WORKDIR /workspace/ +########################################### +########################################### +########################################### + + + + +######################################### +#### Install megatron-core############### +######################################### +RUN pip uninstall -y megatron-core && \ + git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ + cd Megatron-LM-amd_version && \ + pip install -vvv -e . && \ + cd /workspace/ +######################################### +######################################### +######################################### + + + + +####################################### +################apex################### +####################################### +WORKDIR /workspace/ +RUN pip uninstall -y apex && \ + git clone https://github.com/ROCm/apex.git && \ + cd apex && \ + python setup.py install && \ + cd /workspace/ +####################################### +####################################### +####################################### + + + + +################################################################################ +###########################Add torch_memory_saver############################### +################################################################################ +# Set environment variables +ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" +ENV CFLAGS="-D__HIP_PLATFORM_AMD__" +ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" +RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" +################################################################################ +################################################################################ +################################################################################ + + + +######################################## +######Install ray####################### +######################################## +# need to add this patch: https://github.com/ray-project/ray/pull/53531/files +RUN pip uninstall ray -y +RUN pip install "ray[data,train,tune,serve]>=2.47.0" +######################################## +######################################## +######################################## + + + +########################################## +#######Install other dependencies######### +########################################## +RUN pip install "tensordict==0.6.2" --no-deps && \ + pip install accelerate \ + codetiming \ + datasets \ + dill \ + hydra-core \ + liger-kernel \ + numpy \ + pandas \ + peft \ + "pyarrow>=15.0.0" \ + pylatexenc \ + torchdata \ + wandb \ + orjson \ + pybind11 + +WORKDIR /workspace/ +RUN git clone https://github.com/volcengine/verl.git && \ + cd verl && \ + pip install -e . +########################################## +########################################## +########################################## + + + +WORKDIR /workspace/ + +CMD ["/usr/bin/bash"] +CMD ["/usr/bin/bash"] diff --git a/docker/Dockerfile.vllm.sglang.megatron b/docker/Dockerfile.vllm.sglang.megatron.deepseek similarity index 85% rename from docker/Dockerfile.vllm.sglang.megatron rename to docker/Dockerfile.vllm.sglang.megatron.deepseek index 892199017..784537180 100644 --- a/docker/Dockerfile.vllm.sglang.megatron +++ b/docker/Dockerfile.vllm.sglang.megatron.deepseek @@ -61,11 +61,11 @@ RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvi # torch-2.6.0+cu126: cxx11abi=True # see https://github.com/flashinfer-ai/flashinfer/issues/911 # Install sglang-0.4.6.post1 and torch-memory-saver -RUN pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir +RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install --resume-retries 999 torch-memory-saver --no-cache-dir -RUN pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata +RUN pip install --resume-retries 999 --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata -RUN pip install --no-cache-dir "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ +RUN pip install --resume-retries 999 --no-cache-dir "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile \ pytest py-spy pyext pre-commit ruff @@ -76,7 +76,7 @@ RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7 # Fix packages RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" # Install cudnn RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ @@ -86,7 +86,7 @@ RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/ apt-get -y install cudnn-cuda-12 && \ rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb -RUN pip install --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 # Install Apex RUN git clone https://github.com/NVIDIA/apex.git && \ @@ -97,7 +97,7 @@ RUN git clone https://github.com/NVIDIA/apex.git && \ RUN export NVTE_FRAMEWORK=pytorch && pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 # Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.0 +RUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 # Fix opencv RUN pip install opencv-python diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 000000000..1d19e8341 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,79 @@ +# Dockerfiles of verl + +We provide pre-built Docker images for quick setup. And from this version, we utilize a new image release hierarchy for productivity and stability. + +The image types are divided into three large categories: + +- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA. +- **Application Image**: Stable version with inference and training frameworks installed. +- **Preview Image**: Unstable version with the latest frameworks and features. + +The first two types of images are hosted on dockerhub [verlai/verl](https://hub.docker.com/r/verlai/verl) repository, while the preview images are hosted on community repository. + +> The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``. + +## Base Image + +The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``. + +The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. + +The update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages. + +## Application Image + +From this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer. + +There are four types of application images available: + +- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1`` +- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`` +- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.1`` +- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.1`` + +For Megatron 0.13.0, we offer preview images, to use latest codes, just replace ``mcore0.12.1`` with ``mcore0.13.0-preview`` in the above image tag. + +The latest vLLM support is coming soon. + +Docker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details. + +Application images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks. + +## Community Image + +For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyouga/verl) repository and the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. + +For SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. + +See files under ``docker/`` for NGC-based image or if you want to build your own. + +Note that For aws instances with EFA net interface (Sagemaker AI Pod), you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa`` + +## Installation from Docker + +After pulling the desired Docker image and installing desired inference and training frameworks, you can run it with the following steps: + +1. Launch the desired Docker image and attach into it: + +```sh +docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl sleep infinity +docker start verl +docker exec -it verl bash +``` + +2. If you use the images provided, you only need to install verl itself without dependencies: + +```sh +# install the nightly version (recommended) +git clone https://github.com/volcengine/verl && cd verl +pip3 install --no-deps -e . +``` + +[Optional] If you hope to switch between different frameworks, you can install verl with the following command: + +```sh +# install the nightly version (recommended) +git clone https://github.com/volcengine/verl && cd verl +pip3 install -e .[vllm] +pip3 install -e .[sglang] +``` diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 new file mode 100644 index 000000000..eaa12611e --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 @@ -0,0 +1,41 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.6.post5 and torch-memory-saver +RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir + +# Some sglang operations in 0.4.6.post5 require vllm +# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Fix for transformers 4.53.0 +RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep new file mode 100644 index 000000000..dc6907610 --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep @@ -0,0 +1,82 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.6.post5 and torch-memory-saver +RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir + +# Some sglang operations in 0.4.6.post5 require vllm +# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Fix for transformers 4.53.0 +RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge + +# Install DeepEP +## the dependency of IBGDA +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +## Clone and build deepep and deepep-nvshmem +RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout a84a248 + +# Prepare nvshmem +RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ + tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ + cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch + +ENV CUDA_HOME=/usr/local/cuda +### Set MPI environment variables. Having errors when not set. +ENV CPATH=/usr/local/mpi/include:$CPATH +ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH +ENV GDRCOPY_HOME=/workspace/gdrcopy + +## Build deepep-nvshmem +RUN cd deepep-nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install + +ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install +ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH +ENV PATH=$NVSHMEM_DIR/bin:$PATH + +## Build deepep +RUN cd DeepEP && \ + python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview new file mode 100644 index 000000000..0e0bdd43f --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview @@ -0,0 +1,82 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.6.post5 and torch-memory-saver +RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir + +# Some sglang operations in 0.4.6.post5 require vllm +# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0 + +# Fix for transformers 4.53.0 +RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge + +# Install DeepEP +## the dependency of IBGDA +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +## Clone and build deepep and deepep-nvshmem +RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout a84a248 + +# Prepare nvshmem +RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ + tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ + cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch + +ENV CUDA_HOME=/usr/local/cuda +### Set MPI environment variables. Having errors when not set. +ENV CPATH=/usr/local/mpi/include:$CPATH +ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH +ENV GDRCOPY_HOME=/workspace/gdrcopy + +## Build deepep-nvshmem +RUN cd deepep-nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install + +ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install +ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH +ENV PATH=$NVSHMEM_DIR/bin:$PATH + +## Build deepep +RUN cd DeepEP && \ + python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 new file mode 100644 index 000000000..fcf066eda --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 @@ -0,0 +1,47 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 +# torch-2.6.0+cu124: cxx11abi=False +# torch-2.6.0+cu126: cxx11abi=True +# see https://github.com/flashinfer-ai/flashinfer/issues/911 +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) +# vllm-0.8.3 does not support flashinfer>=0.2.3 +# see https://github.com/vllm-project/vllm/pull/15777 +RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Fix for transformers 4.53.0 +RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep new file mode 100644 index 000000000..61b4fdc6a --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep @@ -0,0 +1,88 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 +# torch-2.6.0+cu124: cxx11abi=False +# torch-2.6.0+cu126: cxx11abi=True +# see https://github.com/flashinfer-ai/flashinfer/issues/911 +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) +# vllm-0.8.3 does not support flashinfer>=0.2.3 +# see https://github.com/vllm-project/vllm/pull/15777 +RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Fix for transformers 4.53.0 +RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge + +# Install DeepEP +## the dependency of IBGDA +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +## Clone and build deepep and deepep-nvshmem +RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout a84a248 + +# Prepare nvshmem +RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ + tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ + cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch + +ENV CUDA_HOME=/usr/local/cuda +### Set MPI environment variables. Having errors when not set. +ENV CPATH=/usr/local/mpi/include:$CPATH +ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH +ENV GDRCOPY_HOME=/workspace/gdrcopy + +## Build deepep-nvshmem +RUN cd deepep-nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install + +ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install +ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH +ENV PATH=$NVSHMEM_DIR/bin:$PATH + +## Build deepep +RUN cd DeepEP && \ + python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview new file mode 100644 index 000000000..1fba3fa86 --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview @@ -0,0 +1,85 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 + +# Define environments +ENV MAX_JOBS=32 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 +# torch-2.6.0+cu124: cxx11abi=False +# torch-2.6.0+cu126: cxx11abi=True +# see https://github.com/flashinfer-ai/flashinfer/issues/911 +RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 + +# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) +# vllm-0.8.3 does not support flashinfer>=0.2.3 +# see https://github.com/vllm-project/vllm/pull/15777 +RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ + rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge + +# Install DeepEP +## the dependency of IBGDA +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +## Clone and build deepep and deepep-nvshmem +RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout a84a248 + +# Prepare nvshmem +RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ + tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ + cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch + +ENV CUDA_HOME=/usr/local/cuda +### Set MPI environment variables. Having errors when not set. +ENV CPATH=/usr/local/mpi/include:$CPATH +ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH +ENV GDRCOPY_HOME=/workspace/gdrcopy + +## Build deepep-nvshmem +RUN cd deepep-nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install + +ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install +ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH +ENV PATH=$NVSHMEM_DIR/bin:$PATH + +## Build deepep +RUN cd DeepEP && \ + python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base new file mode 100644 index 000000000..25b1d9431 --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base @@ -0,0 +1,113 @@ +# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks +# Target: verlai/verl:base-v2-cu124-cudnn9.8-torch2.6-fa2.8.0-te2.3 +# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html +FROM nvcr.io/nvidia/pytorch:24.08-py3 + +# Define environments +ENV MAX_JOBS=16 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Define installation arguments +ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ +ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# Set apt source +RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ + { \ + echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ + } > /etc/apt/sources.list + +# Install systemctl +RUN apt-get update && \ + apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ + apt-get clean + +# Install tini +RUN apt-get update && \ + apt-get install -y tini aria2 && \ + apt-get clean + +# Change pip source +RUN pip config set global.index-url "${PIP_INDEX}" && \ + pip config set global.extra-index-url "${PIP_INDEX}" && \ + python -m pip install --upgrade pip + +# Uninstall nv-pytorch fork +RUN pip uninstall -y torch torchvision torchaudio \ + pytorch-quantization pytorch-triton torch-tensorrt \ + xgboost transformer_engine flash_attn apex megatron-core grpcio + +# Reinstall CUDA 12.4 +RUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ + mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 + +RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ + dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ + cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \ + apt-get update && \ + apt-get -y install cuda-toolkit-12-4 && \ + rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ + update-alternatives --set cuda /usr/local/cuda-12.4 && \ + rm -rf /usr/local/cuda-12.6 + +RUN pip install --resume-retries 999 --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 + +RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +# Install flash-attn-2.7.4.post1 (cxx11abi=False) +RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ + pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + +# Fix packages +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +# Install cudnn +RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ + apt-get update && \ + apt-get -y install cudnn-cuda-12 && \ + rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb + +# Install Apex +RUN git clone https://github.com/NVIDIA/apex.git && \ + cd apex && \ + pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ + +# Profiling tools +RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + apt-get update && apt-get install -y libxcb-cursor0 && \ + dpkg -i ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + rm -rf /usr/local/cuda/bin/nsys && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ + rm -rf /usr/local/cuda/bin/nsys-ui && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ + rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb + +# Fix opencv +RUN pip install --resume-retries 999 --no-cache-dir opencv-python + +RUN pip install --resume-retries 999 --no-cache-dir opencv-fixer && \ + python -c "from opencv_fixer import AutoFix; AutoFix()" + +RUN pip install --resume-retries 999 --no-cache-dir cuda-bindings + +# Reset pip config +RUN pip config unset global.index-url && \ + pip config unset global.extra-index-url + +RUN apt-get update && \ + apt-get install -y libfreeimage3 libfreeimage-dev zlib1g htop + diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md b/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md new file mode 100644 index 000000000..6f77fee6a --- /dev/null +++ b/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md @@ -0,0 +1,31 @@ +# verl image with verl v0.4.x + +## Important packages version + +```txt +cuda==12.4 +cudnn==9.8.0 +torch==2.6.0 +flash_attn=2.7.4 +sglang==0.4.6.post5 +vllm==0.8.5.post1 +vidia-cudnn-cu12==9.8.0.87 +transformer_engine==2.3 +megatron.core==core_v0.12.1 +# Preview +transformer_engine==2.5 +megatron.core==core_r0.13.0 +``` + +## Target + +- Base image: + - `verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4` +- App image: + - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`: SGLang requires vLLM in 0.4.6.post5 version, vLLM can have some package conflicts with SGLang + - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1-deepep`: Built with deepep + - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1` + - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1-deepep`: Built with deepep +- Preview image: + - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.13.0-preview` + - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.13.0-preview` \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 new file mode 100644 index 000000000..07435b31c --- /dev/null +++ b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 @@ -0,0 +1,37 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0 + +# Define environments +ENV MAX_JOBS=8 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.8 and torch-memory-saver +# Install FlashInfer Python package +RUN pip install --upgrade pip setuptools packaging +RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 +RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview new file mode 100644 index 000000000..24b831508 --- /dev/null +++ b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview @@ -0,0 +1,37 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0 + +# Define environments +ENV MAX_JOBS=8 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.8 and torch-memory-saver +# Install FlashInfer Python package +RUN pip install --upgrade pip setuptools packaging +RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 +RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base new file mode 100644 index 000000000..915834a0d --- /dev/null +++ b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base @@ -0,0 +1,132 @@ +# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks +# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 +# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html +FROM nvcr.io/nvidia/pytorch:24.08-py3 + +# Define environments +ENV MAX_JOBS=16 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Define installation arguments +ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ +ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# Set apt source +RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ + { \ + echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ + } > /etc/apt/sources.list + +# Install systemctl +RUN apt-get update && \ + apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ + apt-get clean + +# Install tini +RUN apt-get update && \ + apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \ + apt-get clean + +# Change pip source +RUN pip config set global.index-url "${PIP_INDEX}" && \ + pip config set global.extra-index-url "${PIP_INDEX}" && \ + python -m pip install --upgrade pip + +# Uninstall nv-pytorch fork +RUN pip uninstall -y torch torchvision torchaudio \ + pytorch-quantization pytorch-triton torch-tensorrt \ + xgboost transformer_engine flash_attn apex megatron-core grpcio + +RUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 + +# Install flash-attn-2.8.0.post2 (cxx11abi=True) +RUN ABI_FLAG=$(python -c "import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')") && \ + URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl" && \ + FILE="flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl" && \ + wget -nv "${URL}" && \ + pip install --no-cache-dir "${FILE}" + +# Fix packages +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +# Install cudnn +RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ + apt-get update && \ + apt-get -y install cudnn-cuda-12 && \ + rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb + +# Install Apex +RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --resume-retries 999 git+https://github.com/NVIDIA/apex.git + +# Profiling tools +RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + apt-get update && apt-get install -y libxcb-cursor0 + +RUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + rm -rf /usr/local/cuda/bin/nsys && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ + rm -rf /usr/local/cuda/bin/nsys-ui && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ + rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb + +RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas cuda-bindings \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pyext pre-commit ruff + +# Install DeepEP +## the dependency of IBGDA +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +## Clone and build deepep and deepep-nvshmem +RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ + git clone https://github.com/deepseek-ai/DeepEP.git && \ + cd DeepEP && git checkout a84a248 + +# Prepare nvshmem +RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ + tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ + cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch + +ENV CUDA_HOME=/usr/local/cuda +### Set MPI environment variables. Having errors when not set. +ENV CPATH=/usr/local/mpi/include:$CPATH +ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH +ENV GDRCOPY_HOME=/workspace/gdrcopy + +## Build deepep-nvshmem +RUN cd deepep-nvshmem && \ + NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install + +ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install +ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH +ENV PATH=$NVSHMEM_DIR/bin:$PATH + +## Build deepep +RUN cd DeepEP && \ + python setup.py install + +# Reset pip config +RUN pip config unset global.index-url && \ + pip config unset global.extra-index-url + diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md new file mode 100644 index 000000000..c29a7f1f7 --- /dev/null +++ b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md @@ -0,0 +1,27 @@ +# verl image with verl v0.5 + +## Important packages version + +```txt +cuda==12.6 +cudnn==9.8.0 +torch==2.7.1 +flash_attn=2.8.0 ## +sglang==0.4.8 +vllm==0.8.5.post1 +vidia-cudnn-cu12==9.8.0.87 +transformer_engine==2.3 +megatron.core==core_v0.12.1 +# Preview +transformer_engine==2.5 +megatron.core==core_r0.13.0 +``` + +## Target + +- Base image: + - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with deep ep built in +- App image: + - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.12.1` + - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.13.0-preview` +- vllm temporarily not support latest version \ No newline at end of file diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron new file mode 100644 index 000000000..d41ea19d6 --- /dev/null +++ b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron @@ -0,0 +1,36 @@ +# Start from the verl base image +# Dockerfile.base +FROM verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 + +# Define environments +ENV MAX_JOBS=8 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Install sglang-0.4.8 and torch-memory-saver +# Install FlashInfer Python package +RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 +RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir + +# Fix packages +RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pre-commit ruff + +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 + +# Install TransformerEngine +RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 + +# Install Megatron-LM +RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0 + +# Install mbridge +RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base new file mode 100644 index 000000000..29c49faa8 --- /dev/null +++ b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base @@ -0,0 +1,91 @@ +# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks +# Target: verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 +# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html +FROM nvcr.io/nvidia/pytorch:25.02-py3 + +# Define environments +ENV MAX_JOBS=16 +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn +ENV DEBIAN_FRONTEND=noninteractive +ENV NODE_OPTIONS="" +ENV PIP_ROOT_USER_ACTION=ignore +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +# Define installation arguments +ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ +ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# Set apt source +RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ + { \ + echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ + echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ + } > /etc/apt/sources.list + +# Install systemctl +RUN apt-get update && \ + apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ + apt-get clean + +# Install tini +RUN apt-get update && \ + apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \ + apt-get clean + +# Change pip source +RUN pip config set global.index-url "${PIP_INDEX}" && \ + pip config set global.extra-index-url "${PIP_INDEX}" && \ + python -m pip install --upgrade pip + +# Uninstall nv-pytorch fork +RUN pip uninstall -y torch torchvision torchaudio \ + pytorch-quantization pytorch-triton torch-tensorrt \ + xgboost transformer_engine flash_attn apex megatron-core grpcio + +RUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 + +# Install flash-attn-2.8.0.post2 (cxx11abi=True) +RUN ABI_FLAG=$(python -c "import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')") && \ + URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl" && \ + FILE="flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl" && \ + wget -nv "${URL}" && \ + pip install --no-cache-dir "${FILE}" + +# Fix packages +RUN pip uninstall -y pynvml nvidia-ml-py && \ + pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" + +# Install cudnn +RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ + cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ + apt-get update && \ + apt-get -y install cudnn-cuda-12 && \ + rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb + +# Install Apex +RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --resume-retries 999 git+https://github.com/NVIDIA/apex.git + +# Profiling tools +RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + apt-get update && apt-get install -y libxcb-cursor0 + +RUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ + rm -rf /usr/local/cuda/bin/nsys && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ + rm -rf /usr/local/cuda/bin/nsys-ui && \ + ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ + rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb + +RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ + "numpy<2.0.0" "pyarrow>=19.0.1" pandas cuda-bindings \ + ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ + pytest py-spy pre-commit ruff + +# Reset pip config +RUN pip config unset global.index-url && \ + pip config unset global.extra-index-url + diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md new file mode 100644 index 000000000..07d68977f --- /dev/null +++ b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md @@ -0,0 +1,26 @@ +# verl image with verl v0.5 + +## Important packages version + +```txt +cuda==12.8 +cudnn==9.8.0 +torch==2.7.1 +flash_attn=2.8.0 ## +sglang==0.4.8 +transformer_engine==2.5 +megatron.core==core_r0.13.0 +vidia-cudnn-cu12==9.8.0.87 +``` + +## Target + +- Base image: + - `verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with flash infer 0.2.6.post1 built in +- App image: + - `verlai/verl:app-verl0.5-preview-sglang0.4.8-mcore0.13.0-preview` +- vllm temporarily not support latest version + +## !!!Notice!!! + +- pyext is lack of maintainace and cannot work with python 3.12, consider using replacement and deprecating this package. \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 2acaa9b12..8c5db0487 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,9 +1,12 @@ -# verl documents +# verl documentations ## Build the docs ```bash -# Install dependencies. +# If you want to view auto-generated API docstring, please make sure verl is available in python path. For instance, install verl via: +# pip install .. -e[test] + +# Install dependencies needed for building docs. pip install -r requirements-docs.txt # Build the docs. @@ -16,4 +19,4 @@ make html ```bash python -m http.server -d _build/html/ ``` -Launch your browser and navigate to http://localhost:8000 to view the documentation. \ No newline at end of file +Launch your browser and navigate to http://localhost:8000 to view the documentation. Alternatively you could drag the file `_build/html/index.html` to your local browser and view directly. diff --git a/docs/README_vllm0.7.md b/docs/README_vllm0.7.md index 4359bd0fe..e84feddd7 100644 --- a/docs/README_vllm0.7.md +++ b/docs/README_vllm0.7.md @@ -49,11 +49,11 @@ After installation, examples using FSDP as training backends can be used. By def ``` actor_rollout_ref.rollout.enforce_eager=False \ -actor_rollout_ref.rollout.free_cache_engine=False \ +actor_rollout_ref.rollout.free_cache_engine=True \ ``` -For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 115 seconds with vLLM0.6.3, while it is 85 seconds with vLLM0.7.0. By enabling the cudagraph, the generation duration is further reduced to 62 seconds. +For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 85 seconds with vLLM0.7.0. By enabling the cudagraph, the generation duration is further reduced to 62 seconds. **Note:** Currently, if the `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue on the stability of rollout generation time (Some iterations would see generation time bursts) using vLLM's V0 Engine. diff --git a/docs/README_vllm0.8.md b/docs/README_vllm0.8.md index dd40bbeff..d4f509f19 100644 --- a/docs/README_vllm0.8.md +++ b/docs/README_vllm0.8.md @@ -1,5 +1,7 @@ # Upgrading to vLLM >= 0.8 +Last updated: 05/04/2025. + ## Installation Note: This version of verl+vLLM 0.8+ supports **FSDP** for training and **vLLM** for rollout. @@ -34,16 +36,11 @@ vLLM 0.8+ supports cuda graph and V1 engine by default in verl. To enable these ```bash actor_rollout_ref.rollout.enforce_eager=False \ -actor_rollout_ref.rollout.free_cache_engine=False \ +actor_rollout_ref.rollout.free_cache_engine=True \ ``` and also **remove** the environment variable if it exists: -```bash -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS -``` - ## Notes When you just directly upgrade vllm>=0.8, some dependency packages may undergo version changes. If you encounter the following problems: diff --git a/docs/advance/checkpoint.rst b/docs/advance/checkpoint.rst index b9bebcf57..56bec4a75 100644 --- a/docs/advance/checkpoint.rst +++ b/docs/advance/checkpoint.rst @@ -1,6 +1,10 @@ +.. _checkpoint-page: + Using Checkpoints to Support Fault Tolerance Training ===================================================== +Last updated: 06/25/2025. + There could be training errors or machine failure during the whole RLHF training process, so it is recommended to enable checkpoints to minimize your loss. @@ -26,18 +30,20 @@ So the inner checkpoint structure of **FSDP** is like: checkpoints/${trainer.project_name}/${trainer.experiment_name} ├── global_steps_${i} │ ├── actor + │ │ ├── huggingface # default save config and tokenizer, save huggingface model if include ``hf_model`` in checkpoint.contents + │ │ └── fsdp_config.json # FSDP config file, including world_size and fsdp version │ │ ├── model_world_size_{self.world_size}_rank_{self.rank}.pt │ │ ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt │ │ └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt - │ ├── actor_huggingface │ ├── critic + │ │ ├── huggingface + │ │ └── fsdp_config.json │ │ ├── model_world_size_{self.world_size}_rank_{self.rank}.pt │ │ ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt │ │ └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt - │ └── critic_huggingface └── latest_checkpointed_iteration.txt -All model shards, optimizers and extra states are stored togather, in a sharded and distributed way. +All model shards, optimizers and extra states are stored together, in a sharded and distributed way. While **Megatron** current checkpoint structure is: @@ -46,35 +52,26 @@ While **Megatron** current checkpoint structure is: checkpoints/${trainer.project_name}/${trainer.experiment_name} ├── global_steps_${i} │ ├── actor - │ │ ├── huggingface # default save tokenizer, save huggingface model if include ``hf_mode`` in checkpoint.contents - │ │ ├── model # save sharded model, naming the same as Megatron - │ │ │ ├── mp_rank_xx_yyy # xx is tp_rank in 2 digits, yyy is pp_rank in 3 digits - │ │ │ │ └── model_states.pt - │ │ │ └── mp_rank_xx_xxx - │ │ ├── optim - │ │ │ └── distrib_optim_pp{a}_tp{b}_cp{c}_dp{d}.pt - │ │ └── rng_states + │ │ ├── huggingface # default save config and tokenizer, save huggingface model if include ``hf_mode`` in checkpoint.contents + │ │ └── dist_ckpt # save sharded model/optimizer/rng_states, naming the same as Megatron │ └── critic │ │ ├── huggingface - │ │ ├── model - │ │ ├── optim - │ │ └── rng_states + │ │ └── dist_ckpt └── latest_checkpointed_iteration.txt Convert FSDP and Megatron Checkpoints to HuggingFace Format Model ----------------------------------------------------------------- We provide a tool to convert the FSDP and Megatron checkpoints to HuggingFace format model. -The tool is located in ``scripts/model_merger.py``. +The tool is located in ``verl/model_merger``. For older versions of verl that don't include fsdp_config.json in checkpoints, you can use the legacy model merger located at ``verl/scripts/legacy_model_merger.py``. The script supports two main sub-commands: `merge` (to convert and save checkpoints) and `test` (to validate merged checkpoints against a reference model). The arguments for the `merge` sub-command are as follows: .. code:: bash - usage: model_merger.py merge [-h] --backend {fsdp,megatron} --local_dir LOCAL_DIR [--hf_model_path HF_MODEL_PATH] - [--tie-word-embedding] [--is-value-model] [--target_dir TARGET_DIR] - [--hf_upload_path HF_UPLOAD_PATH] [--private] + usage: python -m verl.model_merger merge [-h] --backend {fsdp,megatron} [--local_dir LOCAL_DIR] [--tie-word-embedding] [--is-value-model] [--use_cpu_initialization] [--target_dir TARGET_DIR] + [--hf_upload_path HF_UPLOAD_PATH] [--private] options: -h, --help show this help message and exit @@ -82,10 +79,10 @@ The arguments for the `merge` sub-command are as follows: The backend of the model --local_dir LOCAL_DIR Path to the saved model checkpoints - --hf_model_path HF_MODEL_PATH - (Deprecated) Path to the original Hugging Face model for config. --tie-word-embedding Whether to tie word embedding weights (currently only Megatron supported) --is-value-model Whether the model is a value model (currently only Megatron supported) + --use_cpu_initialization + Whether to use CPU initialization for the model. This is useful for large models that cannot fit into GPU memory during initialization. --target_dir TARGET_DIR Directory to save the merged huggingface model --hf_upload_path HF_UPLOAD_PATH @@ -96,7 +93,17 @@ Example usage for merging Megatron checkpoints: .. code:: bash - python scripts/model_merger.py merge \ + python -m verl.model_merger merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model + +Example usage for distributed merging Megatron checkpoints: + +.. code:: bash + + torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \ --backend megatron \ --tie-word-embedding \ --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ @@ -106,7 +113,7 @@ Example usage for merging FSDP checkpoints: .. code:: bash - python scripts/model_merger.py merge \ + python -m verl.model_merger merge \ --backend fsdp \ --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ --target_dir /path/to/merged_hf_model @@ -148,6 +155,15 @@ Example command to convert the model is as follows: --use_cpu_initialization # Only work for MoE models +Example command to distributed convert the huge model like deepseekv3 671B is as follows: + +.. code:: bash + + torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \ + --hf_model_path deepseek-ai/DeepSeek-V3 \ + --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \ + --use_cpu_initialization # Only work for MoE models + Original Checkpoint Utils ------------------------- diff --git a/docs/advance/dpo_extension.rst b/docs/advance/dpo_extension.rst index 24833b69a..ee9ac619d 100644 --- a/docs/advance/dpo_extension.rst +++ b/docs/advance/dpo_extension.rst @@ -1,6 +1,8 @@ Extend to other RL(HF) algorithms ================================= +Last updated: 02/25/2025. + We already implemented the complete training pipeline of the PPO algorithms. To extend to other algorithms, we analyze the high-level principle to use verl and provide a tutorial to implement the DPO diff --git a/docs/advance/fsdp_extension.rst b/docs/advance/fsdp_extension.rst index bb77283fb..181e10908 100644 --- a/docs/advance/fsdp_extension.rst +++ b/docs/advance/fsdp_extension.rst @@ -2,6 +2,8 @@ Add models with the FSDP backend ================================== +Last updated: 02/09/2025. + Model -------------------------- diff --git a/docs/advance/megatron_extension.rst b/docs/advance/megatron_extension.rst index 9a9ea1ab4..9a52e6017 100644 --- a/docs/advance/megatron_extension.rst +++ b/docs/advance/megatron_extension.rst @@ -1,6 +1,8 @@ Add models with the Megatron-LM backend ========================================= +Last updated: 04/25/2025. + Model ----------- @@ -16,29 +18,3 @@ We list the steps here: 3. Use the right ``LayerSpec`` , ``TransformerConfig`` and ``HuggingfaceConfig`` as arguments to initialize the GPTModel. 4. Return the model at last. - - -Add Models with old version of verl ------------------------------------ - - -The most challenging aspect to use the Megatron-LM backend is implementing -the models for training. Currently, we implement Llama model that -support data parallelism, tensor parallelism, pipeline parallelism (also -vPP) and sequence parallelism. We also implement remove padding (sequence packing) on Llama -model, which can be found in `modeling_llama_megatron.py `_. - -To support other model, users are required to implement: - -1. Implemnt a model similar to ``modeling_llama_megatron.py`` that satisfy the - parallelism requirements of Megatron-LM. Then register your model in - the `registry.py `_. -2. Checkpoint utils that can load full checkpoint (e.g. huggingface - checkpoint) to partitioned models during the runtime. Then register - your loader to ``weight_loader_registry`` in `weight_loader_registry.py `_. -3. Weight loader that synchronize the weight from Megatron to rollout - (vLLM) model. Note that both the actor model and rollout model are - partitioned during runtime. So, it's advisable to map the model name - in actor model implementation. Otherwise, you may need an additional - name mapping and even weight transformation. The weight loader implementation - is in `megatron_weight_loaders.py `_. \ No newline at end of file diff --git a/docs/advance/placement.rst b/docs/advance/placement.rst index a98caa116..43ba761f7 100644 --- a/docs/advance/placement.rst +++ b/docs/advance/placement.rst @@ -1,6 +1,8 @@ Ray API Design Tutorial ======================================= +Last updated: 10/30/2024. + We provide a tutorial for our Ray API design, including: - Ray basic concepts diff --git a/docs/advance/ppo_lora.rst b/docs/advance/ppo_lora.rst index b95d0dace..baf3ab90a 100644 --- a/docs/advance/ppo_lora.rst +++ b/docs/advance/ppo_lora.rst @@ -1,6 +1,8 @@ RL(HF) algorithms with LoRA Support =========================================== +Last updated: 06/05/2025. + We support LoRA (Low-Rank Adaptation) for reinforcement learning algorithms such as PPO, GRPO, and others. LoRA is a parameter-efficient fine-tuning technique that injects trainable low-rank matrices into pre-trained weights (typically linear layers). This reduces memory footprint and compute cost, making it possible to fine-tune large models with limited hardware. diff --git a/docs/advance/rollout_trace.rst b/docs/advance/rollout_trace.rst new file mode 100644 index 000000000..ea203bbc0 --- /dev/null +++ b/docs/advance/rollout_trace.rst @@ -0,0 +1,125 @@ +Trace Function Usage Instructions +======================================== + +Last updated: 07/10/2025. + +Applicable Scenarios +-------------------- + +Agentic RL involves multiple turns of conversations, tool invocations, and user interactions during the rollout process. During the Model Training process, it is necessary to track function calls, inputs, and outputs to understand the flow path of data within the application. The Trace feature helps, in complex multi-round conversations, to view the transformation of data during each interaction and the entire process leading to the final output by recording the inputs, outputs, and corresponding timestamps of functions, which is conducive to understanding the details of how the model processes data and optimizing the training results. + +The Trace feature integrates commonly used Agent trace tools, including wandb weave and mlflow, which are already supported. Users can choose the appropriate trace tool according to their own needs and preferences. Here, we introduce the usage of each tool. + + +Trace Parameter Configuration +----------------------------- + +- ``actor_rollout_ref.rollout.trace.backend=mlflow|weave`` # the trace backend type +- ``actor_rollout_ref.rollout.trace.token2text=True`` # To show decoded text in trace view + + +Glossary +-------- + ++----------------+------------------------------------------------------------------------------------------------------+ +| Object | Explaination | ++================+======================================================================================================+ +| trajectory | A complete multi-turn conversation includes: | +| | 1. LLM output at least once | +| | 2. Tool Call | ++----------------+------------------------------------------------------------------------------------------------------+ +| step | The training step corresponds to the global_steps variable in the trainer | ++----------------+------------------------------------------------------------------------------------------------------+ +| sample_index | The identifier of the sample, defined in the extra_info.index of the dataset. It is usually a number,| +| | but may also be a uuid in some cases. | ++----------------+------------------------------------------------------------------------------------------------------+ +| rollout_n | In the GROP algorithm, each sample is rolled out n times. rollout_n represents the serial number of | +| | the rollout. | ++----------------+------------------------------------------------------------------------------------------------------+ +| validate | Whether the test dataset is used for evaluation? | ++----------------+------------------------------------------------------------------------------------------------------+ + +Rollout trace functions +----------------------- + +There are 2 functions used for tracing: + +1. ``rollout_trace_op``: This is a decorator function used to mark the functions to trace. In default, only few method has it, you can add it to more functions to trace more infor. +2. ``rollout_trace_attr``: This function is used to mark the entry of a trajectory and input some info to trace. If you add new type of agent, you may need to add it to enable trace. + + +Usage of wandb weave +-------------------- + +1.1 Basic Configuration +~~~~~~~~~~~~~~~~~~~~~~~ + +1. Set the ``WANDB_API_KEY`` environment variable +2. Configuration Parameters + + 1. ``actor_rollout_ref.rollout.trace.backend=weave`` + 2. ``trainer.logger=['console', 'wandb']``: This item is optional. Trace and logger are independent functions. When using Weave, it is recommended to also enable the wandb logger to implement both functions in one system. + 3. ``trainer.project_name=$project_name`` + 4. ``trainer.experiment_name=$experiment_name`` + 5. ``actor_rollout_ref.rollout.mode=async``: Since trace is mainly used for agentic RL, need to enable agent toop using async mode for either vllm or sglang. + +Note: +The Weave Free Plan comes with a default monthly network traffic allowance of 1GB. During the training process, the amount of trace data generated is substantial, reaching dozens of gigabytes per day, so it is necessary to select an appropriate wandb plan. + + +1.2 View Trace Logs +~~~~~~~~~~~~~~~~~~~ + +After executing the training, on the project page, you can see the WEAVE sidebar. Click Traces to view it. + +Each Trace project corresponds to a trajectory. You can filter and select the trajectories you need to view by step, sample_index, rollout_n, and experiment_name. + +After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the input and output content. + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_list.png?raw=true + +1.3 Compare Trace Logs +~~~~~~~~~~~~~~~~~~~~~~ + +Weave can select multiple trace items and then compare the differences among them. + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_compare.png?raw=true + +Usage of mlflow +--------------- + +1. Basic Configuration +~~~~~~~~~~~~~~~~~~~~~~ + +1. Set the ``MLFLOW_TRACKING_URI`` environment variable, which can be: + + 1. Http and https URLs corresponding to online services + 2. Local files or directories, such as ``sqlite:////tmp/mlruns.db``, indicate that data is stored in ``/tmp/mlruns.db``. When using local files, it is necessary to initialize the file first (e.g., start the UI: ``mlflow ui --backend-store-uri sqlite:////tmp/mlruns.db``) to avoid conflicts when multiple workers create files simultaneously. + +2. Configuration Parameters + + 1. ``actor_rollout_ref.rollout.trace.backend=mlflow`` + 2. ``trainer.logger=['console', 'mlflow']``. This item is optional. Trace and logger are independent functions. When using mlflow, it is recommended to also enable the mlflow logger to implement both functions in one system. + 3. ``trainer.project_name=$project_name`` + 4. ``trainer.experiment_name=$experiment_name`` + + +2. View Log +~~~~~~~~~~~ + +Since ``trainer.project_name`` corresponds to Experiments in mlflow, in the mlflow view, you need to select the corresponding project name, then click the "Traces" tab to view traces. Among them, ``trainer.experiment_name`` corresponds to the experiment_name of tags, and tags corresponding to step, sample_index, rollout_n, etc., are used for filtering and viewing. + +For example, searching for ``"tags.step = '1'"`` can display all trajectories of step 1. + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_list.png?raw=true + +Opening one of the trajectories allows you to view each function call process within it. + +After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the content. + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_view.png?raw=true + +Note: + +1. mlflow does not support comparing multiple traces +2. rollout_trace can not associate the mlflow trace with the run, so the trace content cannot be seen in the mlflow run logs. diff --git a/docs/advance/rope.rst b/docs/advance/rope.rst index 8382b8f4a..9463549e4 100644 --- a/docs/advance/rope.rst +++ b/docs/advance/rope.rst @@ -1,6 +1,8 @@ RoPE Scaling override ======================================= +Last updated: 05/14/2025. + Some models such as `Qwen/Qwen2.5-7B-Instruct `_ support RoPE Scaling but don't have it defined in their config.json file. For example, this model supports this configuration: diff --git a/docs/algo/baseline.md b/docs/algo/baseline.md index 4d23a9c15..ce74c367d 100644 --- a/docs/algo/baseline.md +++ b/docs/algo/baseline.md @@ -1,7 +1,11 @@ # Algorithm Baselines +Last updated: 06/18/2025. + ## Math related datasets +### GSM8k + Assuming GSM8k/math dataset is preprocessed via: ```bash @@ -19,17 +23,41 @@ Refer to the table below to reproduce RL training from different pre-trained che | NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | hf checkpoint | 36.4 | [Qwen blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | | NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | [command and log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) | | NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | PRIME | 58.7 | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen.sh), [wandb](https://api.wandb.ai/links/zefan-wang-thu-tsinghua-university/rxd1btvb) | +| NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | GRPO-LoRA | 54.3 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.543.log)| +| NVIDIA GPU | Qwen/Qwen2.5-1.5B-Instruct | GRPO-LoRA | 77.9 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-1.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.779.log)| +| NVIDIA GPU | Qwen/Qwen2.5-3B-Instruct | GRPO-LoRA | 86.1 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-3B-bsz64_2-prompt512-resp1024-lorarank32-score0.861.log)| | NVIDIA GPU | deepseek-ai/deepseek-llm-7b-chat | PPO (Megatron) | 69.5 [1] | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/deepseek-llm-7b-chat-megatron-bsz256_4-prompt512-resp512-0.695.log), [wandb](https://wandb.ai/verl-team/verl_megatron_gsm8k_examples/runs/10fetyr3) | | NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO | 89 | [script](https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh) | | NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO (FSDP2) | 89.8 | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) | | NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO (Megatron) | 89.6 | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b_math_megatron.log) | | NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | [script](https://github.com/eric-haibin-lin/verl/blob/main/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh), [wandb](https://wandb.ai/liziniu1997/verl_remax_example_gsm8k/runs/vxl10pln) | | NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | SPPO | 65.6 (MATH) | [SPPO script](https://github.com/volcengine/verl/tree/main/recipe/sppo/README.md) | +| NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | GRPO-LoRA | 93.4 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-7B-bsz64_8-prompt512-resp1024-lorarank32-score0.934.log)| | NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | Instruct model | 83.7 | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | | NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | RLOO (Megatron) | 92.3 | [wandb](https://api.wandb.ai/links/ppo_dev/sbuiuf2d) | | NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | SPIN | 92 | [script](https://github.com/volcengine/verl/tree/main/recipe/spin/README.md) | +| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/ab86c4va) | +| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG (Megatron) | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math_megatron.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/yy8bheu8) | +| NVIDIA GPU | Qwen/Qwen2.5-VL-7B-Instruct | GRPO (Megatron) | 65.4 (GEO3k) | [script](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh), [wandb](https://api.wandb.ai/links/megatron-core-moe-dev/1yngvkek) | | AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | PPO | 70.5 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/ppo_run_deepseek7b_llm.log) | | AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | GRPO | 71.4 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/grpo_run_deepseek7b_llm.log) | +| NVIDIA GPU | Qwen/Qwen2.5-14B-Instruct | GRPO-LoRA | 94.6 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-14B-bsz64_8-prompt512-resp1024-lorarank32-score0.946.log)| +| NVIDIA GPU | Qwen/Qwen2.5-32B-Instruct | GRPO-LoRA | 95.8 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-32B-bsz64_8-prompt512-resp1024-lorarank32-score0.958.log)| +| NVIDIA GPU | Qwen/Qwen2.5-72B-Instruct | GRPO-LoRA | 96.0 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-72B-bs64_8-prompt512-resp1024-lorarank32-score0.960.log)| + +### DAPO math-17k + +- Training DAPO math-17k dataset: https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k +- Testing: AIME'24: https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024 + +Note: +- For Qwen/Qwen2.5-Math-7B, we directly modify the max_position_embeddings to 32768 without observing performance degradation in order to train longer response length. + +| Hardware | Model | Method | Test score | Details | +|-------------|----------------------------------|-------------------|--------------|---------| +| NVIDIA GPU | Qwen/Qwen2.5-Math-7B (32k) | DAPO | 36.3 | [command](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_7b_math.sh), [logs](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361)| + + ## Coding related datasets @@ -44,4 +72,4 @@ Below is the result on leetcode if not specified otherwise. [1] During evaluation, we have only extracted answers following the format `"####"`. A more flexible answer extraction, longer response length, and better prompt engineering may lead to a higher score. -[2] The default value of `actor_rollout_ref.actor.entropy_coeff` is set to `0.0` since verl 0.3.x on 2025-05-30, which is different from previous versions. \ No newline at end of file +[2] The default value of `actor_rollout_ref.actor.entropy_coeff` is set to `0.0` since verl 0.3.x on 2025-05-30, which is different from previous versions. diff --git a/docs/algo/dapo.md b/docs/algo/dapo.md index 009600313..96f242eaa 100644 --- a/docs/algo/dapo.md +++ b/docs/algo/dapo.md @@ -1,11 +1,12 @@ # Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) -> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) +Last updated: 06/19/2025. -🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) +> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) +🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) -> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. +> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. > > ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) @@ -24,22 +25,25 @@ cd verl # Repo root export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster # Set the runtime environment like env vars and pip packages for the Ray cluster in yaml -export RUNTIME_ENV="./verl/trainer/runtime_env.yaml" -bash recipe/dapo/run_dapo_qwen2.5_32b.sh +export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster +bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts ``` ## Reproduction Runs -| Setup | AIME 2024 Acc. | Training Script | Training Record | -| -------------------------------------------- | -------------- | ---------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | -| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Dynamic Sampling | 50% | [run_dapo_wo_ds_qwen2.5_32b.sh](./run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO | 52% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| Setup | AIME 2024 Acc. | Hardware | Image | Commit | Environment Variables | Training Script | Training Record | +| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | +| DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -## Configuration +> [!IMPORTANT] +> +> **📢 Call for Contribution!** +> +> Welcome to submit your reproduction runs and setups! -> [!NOTE] -> Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. +## Configuration ### Separated Clip Epsilons (-> Clip-Higher) @@ -159,3 +163,25 @@ if self.overlong_buffer_cfg.enable: overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) reward += overlong_reward ``` + +## FAQ + +### Where is the "Overlong Filtering" in the paper? + +Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. + +### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)? + +[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features. + +[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features. + +### Why can't I produce similar results after modifications? + +RL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve. + +We strongly recommend to only modify one thing at a time. + +We also list some known problems here: + +1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation. diff --git a/docs/algo/entropy.md b/docs/algo/entropy.md new file mode 100644 index 000000000..46153b7e8 --- /dev/null +++ b/docs/algo/entropy.md @@ -0,0 +1,115 @@ +# Recipe: Entropy Mechanism + +Last updated: 06/27/2025. + + +
+ + The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning. + +[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue +)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861) + + + + +
+ + +## 🎉News + +- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29). +- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. + + + +## ✨Getting started + +After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run: + +``` +cd verl +conda activate your_env +bash recipe/dapo/7b_kl_cov.sh +``` + +While for training Qwen2.5-32B on multi nodes, you can run the following commands: + +``` +cd verl +conda activate your_env +bash recipe/dapo/32b_kl_cov.sh +``` + +## 📖Introduction + +
+ issue +
+ +This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. + +
+ issue +
+ +Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. + +## 📃Evaluation + +
+ issue +
+ + +Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. +| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** | +| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: | +| *Qwen2.5-7B* | | | | | | | | | +| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 | +| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 | +| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 | +| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** | +| *Qwen2.5-32B* | | | | | | | | | +| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 | +| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 | +| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 | +| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** | + +Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively. + + +## 🎈Citation +If you find this paper or repo helpful, please cite us. + +```bibtex +@article{cui2025entropy, + title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models}, + author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others}, + journal={arXiv preprint arXiv:2505.22617}, + year={2025} +} +``` +## 🌻Acknowledgement +We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! + +## 📬 Contact + +For questions, discussion, or collaboration opportunities, feel free to contact: +- Ganqu Cui: cuiganqu@pjlab.org.cn +- Yuchen Zhang: yuchen.zhang2003@gmail.com +- Jiacheng Chen: jackchan9345@gmail.com +- Ning Ding: ningding.cs@gmail.com + diff --git a/docs/algo/gpg.md b/docs/algo/gpg.md new file mode 100644 index 000000000..36bede8c3 --- /dev/null +++ b/docs/algo/gpg.md @@ -0,0 +1,36 @@ +# GPG: Group Policy Gradient + +Last updated: 07/03/2025. + +Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning +](https://arxiv.org/abs/2504.02546). + +## Key Components +- Use a corrected advantage function to improve policy gradient accuracy and training efficiency. +- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO) + +## Configuration +To configure GPG within the framework, use the following YAML settings. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + policy_loss: + loss_mode: "gpg" +``` + +## Advanced Extensions +GPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + use_kl_loss: True # enable kl regularization + kl_loss_coef: 0.01 + policy_loss: + loss_mode: "gpg" +``` \ No newline at end of file diff --git a/docs/algo/grpo.md b/docs/algo/grpo.md index 92e3bd790..ba6d8ddab 100644 --- a/docs/algo/grpo.md +++ b/docs/algo/grpo.md @@ -1,5 +1,7 @@ # Group Relative Policy Optimization (GRPO) +Last updated: 05/31/2025. + In reinforcement learning, classic algorithms like PPO rely on a "critic" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. GRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows: diff --git a/docs/algo/opo.md b/docs/algo/opo.md index 01db5cccd..338f3a762 100644 --- a/docs/algo/opo.md +++ b/docs/algo/opo.md @@ -1,5 +1,7 @@ # On-Policy RL with Optimal Reward Baseline (OPO) +Last updated: 06/02/2025. + Loose on-policy constraints and suboptimal baselines in reinforcement learning often lead to training instability such as large policy shifts and entropy collapse. OPO addresses these challenges by using exact on-policy training with the theretically optimal reward baseline for advantage estimation. It achieves lower policy shifts and higher output entropy, encouraging more diverse and less repetitive responses. OPO uses group sampling to generate multiple outputs for each input like GRPO. Unlike group-based algorithms which typically use the mean reward of a group as its baseline, OPO employs a theoretically optimal baseline: the length-weighted reward of the group. It also omits the standard deviation normalization. By adopting these two key components, OPO enables the training of a single policy model with the objective of maximizing only the expected reward. For more detailes, refer to the original paper [On-Policy RL with Optimal Reward Baseline](https://arxiv.org/pdf/2505.23585). diff --git a/docs/algo/ppo.md b/docs/algo/ppo.md index 7d4069414..d1f3046e5 100644 --- a/docs/algo/ppo.md +++ b/docs/algo/ppo.md @@ -1,5 +1,7 @@ # Proximal Policy Optimization (PPO) +Last updated: 06/19/2025. + Proximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning. Traditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from: @@ -37,7 +39,7 @@ Most critic configs are similar to those of actors. Note that the critic model i - `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic +- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs` - `algorithm.gemma`: discount factor @@ -86,7 +88,7 @@ Qwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/ver bash run_gemma.sh trainer.n_gpus_per_node=1 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - trainer.logger=['console'] \ + trainer.logger=console \ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ data.train_batch_size=256 \ diff --git a/docs/algo/spin.md b/docs/algo/spin.md index 4b50532d0..c2a834262 100644 --- a/docs/algo/spin.md +++ b/docs/algo/spin.md @@ -1,5 +1,7 @@ # Recipe: Self-Play Fine-Tuning (SPIN) +Last updated: 05/31/2025. + `verl` provides a recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory. **Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models: diff --git a/docs/algo/sppo.md b/docs/algo/sppo.md index ff648849d..bf7c4e9e6 100644 --- a/docs/algo/sppo.md +++ b/docs/algo/sppo.md @@ -1,5 +1,7 @@ # Recipe: Self-Play Preference Optimization (SPPO) +Last updated: 05/28/2025. + verl provides a community recipe implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets. Paper Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) diff --git a/docs/amd_tutorial/amd_build_dockerfile_page.rst b/docs/amd_tutorial/amd_build_dockerfile_page.rst index 3b25df232..51efa247c 100644 --- a/docs/amd_tutorial/amd_build_dockerfile_page.rst +++ b/docs/amd_tutorial/amd_build_dockerfile_page.rst @@ -1,6 +1,8 @@ Getting started with AMD (ROCM Kernel) ===================================================== +Last updated: 07/06/2025. + Author: `Yusheng Su `_ Setup @@ -14,40 +16,267 @@ docker/Dockerfile.rocm .. code-block:: bash - # Build the docker in the repo dir: - # docker build -f docker/Dockerfile.rocm -t verl-rocm . - # docker images # you can find your built docker + FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" + + SHELL ["/bin/bash", "-ceuxo", "pipefail"] + + ENV MAX_JOBS=512 + + ENV PATH="/usr/local/python3.12/bin:$PATH" + RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ + ln -sf /usr/bin/pip3.12 /usr/bin/pip + + ############################################ + RUN apt-get update + RUN apt-get install -y pkg-config liblzma-dev + ############################################ + + ########################################### + ##########Install TransformerEngine######## + ########################################### + WORKDIR /workspace/ + # transformer-engine install + # https://github.com/ROCm/TransformerEngine + RUN rm -rf TransformerEngine + RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git + WORKDIR /workspace/TransformerEngine + git checkout 236178e5 + # git checkout bb061ade + # git checkout 864405c + ENV NVTE_FRAMEWORK=pytorch + ENV NVTE_ROCM_ARCH=gfx942 + ENV NVTE_USE_HIPBLASLT=1 + ENV NVTE_USE_ROCM=1 + # export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" + ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv + WORKDIR /workspace/ + ########################################### + ########################################### + ########################################### + + + + + + #################################################################################### + ################Install vllm - sglang require vllm 0.6.7 dependency################# + #################################################################################### + #### Require vllm 0.6.7 - checkout 113274a0 + WORKDIR /workspace/ + RUN rm -rf vllm + RUN pip uninstall -y vllm + # Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html + RUN git clone https://github.com/ROCm/vllm.git + # git clone https://github.com/vllm-project/vllm.git + WORKDIR /workspace/vllm + RUN git checkout 113274a0 + ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" + #ENV MAX_JOBS=512 + ENV MAX_JOBS=${MAX_JOBS} + RUN pip install "boto3>=1.26.0" + RUN pip install setuptools_scm + # will add src into py. You can delete the repo + RUN python3 setup.py install + WORKDIR /workspace/ + #################################################################################### + #################################################################################### + #################################################################################### + - # Support - Traing: fsdp; Inference: vllm - # FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - # Support - Traing: fsdp; Inference: vllm, sglang - FROM lmsysorg/sglang:v0.4.6.post5-rocm630 + ########################################### + ############For hack docker################ + ########################################### + RUN pip install setuptools==75.8.0 + ########################################### + ########################################### + ########################################### - # Set working directory - # WORKDIR $PWD/app + + ########################################### + ############build sgalng################### + ########################################### # Set environment variables + ENV BASE_DIR=/sgl-workspace + ENV BUILD_TYPE=all + ENV SGL_REPO=https://github.com/sgl-project/sglang + ENV SGL_BRANCH=v0.4.6.post5 + ENV TRITON_REPO=https://github.com/ROCm/triton.git + ENV TRITON_COMMIT=improve_fa_decode_3.0.0 + ENV AITER_REPO=https://github.com/ROCm/aiter.git + ENV AITER_COMMIT=v0.1.2 + # v0.1.2 version - commit id: 9d11f47 + # ENV AITER_COMMIT=9d11f47 + ENV HIP_FORCE_DEV_KERNARG=1 + ENV HSA_NO_SCRATCH_RECLAIM=1 + ENV SGLANG_SET_CPU_AFFINITY=1 + ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 + ENV NCCL_MIN_NCHANNELS=112 + ENV MOE_PADDING=1 + ENV VLLM_FP8_PADDING=1 + ENV VLLM_FP8_ACT_PADDING=1 + ENV VLLM_FP8_WEIGHT_PADDING=1 + ENV VLLM_FP8_REDUCE_CONV=1 + ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 + ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 + ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" + ENV AMDGPU_TARGETS=gfx942 + ENV ROCM_ARCH=gfx942 ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - + # Switch to working directory + WORKDIR /sgl-workspace + # Clean and create directory + RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace + + # Clone and build sglang + RUN git clone ${SGL_REPO} \ + && cd sglang \ + && git checkout ${SGL_BRANCH} || echo "Using default branch" \ + && cd sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && python setup_rocm.py install \ + && cd .. \ + && if [ "$BUILD_TYPE" = "srt" ]; then \ + python -m pip --no-cache-dir install -e "python[srt_hip]"; \ + else \ + python -m pip --no-cache-dir install -e "python[all_hip]"; \ + fi \ + && cd /sgl-workspace \ + && cp -r /sgl-workspace/sglang /sglang \ + && python -m pip cache purge + + # Install common Python packages + RUN pip install IPython orjson python-multipart torchao pybind11 + # Rebuild Triton + RUN pip uninstall -y triton || true \ + && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_COMMIT} \ + && cd python \ + && python3 setup.py install \ + && cd /sgl-workspace + # ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" + # ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" + + # Build aiter + #version: Commit 9d11f47 + # && git checkout ${AITER_COMMIT} \ + RUN pip uninstall -y aiter || true + RUN git clone ${AITER_REPO} \ + && cd aiter \ + && git checkout ${AITER_COMMIT} \ + && git submodule sync \ + && git submodule update --init --recursive \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ + && cd /sgl-workspace + + # Copy MI300X config + RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ + /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' | \ + xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} + + # Environment setup complete. + RUN echo "Environment setup complete." + + WORKDIR /workspace/ + ########################################### + ########################################### + ########################################### + + + + + + + ########################################### + ###############vllm v0.8.5################# + ########################################### + WORKDIR /workspace/ + + ENV VLLM_TARGET_DEVICE=rocm + ENV ROCM_PATH=/opt/rocm + ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev + # Find the repo path in: DockerFile/Dockerfile.rocm_yang + # RUN git clone https://github.com/RLFoundation/vllm-patch.git + RUN pip uninstall -y vllm || true + RUN rm -rf vllm-patch + RUN git clone https://github.com/RLFoundation/vllm-patch.git \ + && cd vllm-patch \ + && git checkout v0.8.5-sleep-numa \ + && rm -rf build/ dist/ *.egg-info \ + && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ + && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install + # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop + WORKDIR /workspace/ + ########################################### + ########################################### + ########################################### + + + + + ######################################### + #### Install megatron-core############### + ######################################### + RUN pip uninstall -y megatron-core && \ + git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ + cd Megatron-LM-amd_version && \ + pip install -vvv -e . && \ + cd /workspace/ + ######################################### + ######################################### + ######################################### + + + + + ####################################### + ################apex################### + ####################################### + WORKDIR /workspace/ + RUN pip uninstall -y apex && \ + git clone git@github.com:ROCm/apex.git && \ + cd apex && \ + python setup.py install && \ + cd /workspace/ + ####################################### + ####################################### + ####################################### + + + ################################################################################ + ###########################Add torch_memory_saver############################### + ################################################################################ + # Set environment variables ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" ENV CFLAGS="-D__HIP_PLATFORM_AMD__" ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" + RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" + ################################################################################ + ################################################################################ + ################################################################################ - # Install vllm - RUN pip uninstall -y vllm && \ - rm -rf vllm && \ - git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \ - cd vllm && \ - MAX_JOBS=$(nproc) python3 setup.py install && \ - cd .. && \ - rm -rf vllm - # Copy the entire project directory - COPY . . - # Install dependencies - RUN pip install "tensordict<0.6" --no-deps && \ + ######################################## + ######Install ray####################### + ######################################## + # need to add this patch: https://github.com/ray-project/ray/pull/53531/files + RUN pip uninstall ray -y + RUN pip install "ray[data,train,tune,serve]>=2.47.0" + ######################################## + ######################################## + ######################################## + + + ########################################## + #######Install other dependencies######### + ########################################## + RUN pip install "tensordict==0.6.2" --no-deps && \ pip install accelerate \ codetiming \ datasets \ @@ -59,16 +288,21 @@ docker/Dockerfile.rocm peft \ "pyarrow>=15.0.0" \ pylatexenc \ - "ray[data,train,tune,serve]>=2.45.0" \ torchdata \ - transformers \ wandb \ orjson \ - pybind11 && \ - pip install -e . --no-deps + pybind11 + + WORKDIR /workspace/ + RUN git clone https://github.com/volcengine/verl.git && \ + cd verl && \ + pip install -e . + ########################################## + ########################################## + ########################################## - # Install torch_memory_saver - RUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps + WORKDIR /workspace/ + CMD ["/usr/bin/bash"] Build the image: @@ -76,7 +310,20 @@ Build the image: .. code-block:: bash - docker build -t verl-rocm . + docker docker/build -t verl-rocm . + +Run the container +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Note: You can pull the docker from this DockerHub: [RLSys Foundation](https://hub.docker.com/u/yushengsuthu) +Pull the image: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + docker pull yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 + + docker tag yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 verl-rocm:latest Run the container ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -109,7 +356,7 @@ Example ------- Due to to special setting in AMD (ROCM) torch, -1. If your ``ray>=2.45.0`` (default), you need to set ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training. +1. If your ``ray>=2.45.0`` (default), you need to set ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training and add this [patch](https://github.com/ray-project/ray/pull/53531/files). 2. If your ``ray<2.45.0``, you need to set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` when starting ray in verl's RLHF training. Inference ``$ENGINE`` can be ``vllm`` or ``sglang``. We choose ``vllm`` as default in the following examples. @@ -124,6 +371,8 @@ PPO YOUR_RUN_NAME=r1-training_ppo-upstream # export HYDRA_FULL_ERROR=1 + export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + # [ray] < 2.45.0 #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 @@ -156,11 +405,10 @@ PPO critic.model.path=$MODEL_PATH \ critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name=$YOUR_PROJECT_NAME \ trainer.experiment_name=$YOUR_RUN_NAME \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=$GPUS_PER_NODE \ trainer.nnodes=1 \ trainer.save_freq=10 \ @@ -177,6 +425,8 @@ GRPO # export HYDRA_FULL_ERROR=1 # export FSDP_VERBOSE=1 + #export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + # [ray] < 2.45.0 #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 @@ -217,7 +467,7 @@ GRPO actor_rollout_ref.ref.fsdp_config.param_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name=$YOUR_PROJECT_NAME \ trainer.experiment_name=$YOUR_RUN_NAME \ trainer.n_gpus_per_node=$GPUS_PER_NODE \ @@ -303,6 +553,9 @@ slurm_script.sh export HSA_NO_SCRATCH_RECLAIM=1 ########################################################################## + ## Assign using GPUs + export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + ### For rocm and training script # [ray] < 2.45.0 #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 @@ -405,8 +658,6 @@ slurm_script.sh echo "IP Head: $ip_head" # make sure we set environment variables before Ray initialization - # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: - # export VLLM_ATTENTION_BACKEND=XFORMERS # Print out all env variables printenv @@ -524,7 +775,7 @@ slurm_script.sh critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.0001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \ diff --git a/docs/amd_tutorial/amd_vllm_page.rst b/docs/amd_tutorial/amd_vllm_page.rst index bc4db4f73..9c64755cb 100644 --- a/docs/amd_tutorial/amd_vllm_page.rst +++ b/docs/amd_tutorial/amd_vllm_page.rst @@ -1,6 +1,8 @@ verl performance tuning for AMD (ROCm Kernel) ===================================================== +Last updated: 04/25/2025. + Author: `Yang Wang `_ Patch vLLM to Enable Sleep Mode for AMD GPUs @@ -96,9 +98,8 @@ Our investigation shows that ROCm may trigger an unexpected crash when attemptin seed=config.get('seed', 0), ) -Then, you can enable CUDA graph by setting the following environment variables (see `this page `_): +Then, you can choose to enable CUDA graph by setting the following environment variables (see `this page `_): .. code-block:: bash actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=False \ diff --git a/docs/api/data.rst b/docs/api/data.rst index e7efca203..1f6018bc9 100644 --- a/docs/api/data.rst +++ b/docs/api/data.rst @@ -1,6 +1,8 @@ Data interface ========================= +Last updated: 05/19/2025 (API docstrings are auto-generated). + DataProto is the interface for data exchange. The :class:`verl.DataProto` class contains two key members: diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst index 369e59776..44ea366ff 100644 --- a/docs/api/single_controller.rst +++ b/docs/api/single_controller.rst @@ -1,6 +1,8 @@ Single Controller interface ============================ +Last updated: 05/27/2025 (API docstrings are auto-generated). + The Single Controller provides a unified interface for managing distributed workers using Ray or other backends and executing functions across them. It simplifies the process of dispatching tasks and collecting results, particularly diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst index cd308c44d..abfa51f01 100644 --- a/docs/api/trainer.rst +++ b/docs/api/trainer.rst @@ -1,6 +1,8 @@ Trainer Interface ================================ +Last updated: 06/08/2025 (API docstrings are auto-generated). + Trainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged. .. autosummary:: @@ -12,17 +14,18 @@ Trainers drive the training loop. Introducing new trainer classes in case of new Core APIs ~~~~~~~~~~~~~~~~~ -.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer +.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer :members: __init__, init_workers, fit - .. automodule:: verl.utils.tokenizer :members: hf_tokenizer - .. automodule:: verl.trainer.ppo.core_algos :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty - .. automodule:: verl.trainer.ppo.reward :members: load_reward_manager, compute_reward, compute_reward_async + +.. autoclass:: verl.workers.reward_manager.NaiveRewardManager + +.. autoclass:: verl.workers.reward_manager.DAPORewardManager diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 3ac4380b0..e15e3a5a3 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -1,6 +1,8 @@ Utilities ============ +Last updated: 05/19/2025 (API docstrings are auto-generated). + This section documents the utility functions and classes in the VERL library. Python Functional Utilities @@ -58,7 +60,7 @@ Ulysses Utilities -------------------- .. automodule:: verl.utils.ulysses - :members: gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + :members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs FSDP Utilities ------------------ @@ -69,6 +71,6 @@ FSDP Utilities Debug Utilities ------------------- -.. automodule:: verl.utils.debug +.. automodule:: verl.utils.profiler :members: log_gpu_memory_usage, GPUMemoryLogger diff --git a/docs/ascend_tutorial/ascend_profiling.rst b/docs/ascend_tutorial/ascend_profiling.rst new file mode 100644 index 000000000..db2972d78 --- /dev/null +++ b/docs/ascend_tutorial/ascend_profiling.rst @@ -0,0 +1,100 @@ +在昇腾设备上基于FSDP后端进行数据采集 +==================================== + +Last updated: 07/14/2025. + +这是一份在昇腾设备上基于FSDP后端使用GRPO或DAPO算法进行数据采集的教程。 + +配置 +---- + +复用verl/trainer/config/ppo_trainer.yaml中的配置项控制采集的模式和步数, +通过verl/trainer/config/npu_profile/npu_profile.yaml中的配置项控制例如采集等级等参数。 + +全局采集控制 +~~~~~~~~~~~~ + +通过 ppo_trainer.yaml 中的参数控制采集步数和模式: + +- trainer.profile_steps: + 该参数可以设置为一个包含采集步数的列表,例如[2, + 4], 意味着将会采集第二步和第四步。如果该参数为null,则代表不进行采集 +- actor_rollout_ref.profiler: + 控制采集的ranks和模式 + + - all_ranks:设为True代表对所有rank进行采集 + - ranks:当all_ranks不为True时, + 通过ranks参数控制需要采集的rank,该参数设置为一个包含采集rank的列表, 例如[0, + 1] + - discrete: + 控制采集的模式。当该参数设置为False,代表采集端到端的数据;当该参数设置为True,代表采用离散模式分训练阶段采集数据 + +通过 npu_profile.yaml 中的参数控制具体采集行为: + +- save_path:采集数据的存放路径 +- level:采集等级,可选项为level_none、level0、level1和level2 + + - level_none:不采集所有Level层级控制的数据,即关闭profiler_level + - level0:采集上层应用数据、底层NPU数据以及NPU上执行的算子信息 + - level1:在level0的基础上多采集CANN层AscendCL数据和NPU上执行的AI + Core性能指标信息 + - level2:在level1的基础上多采集CANN层Runtime数据以及AI CPU + +- record_shapes:是否记录张量形状 +- with_memory:是否启用内存分析 +- with_npu:是否采集device侧性能数据 +- with_cpu:是否采集host侧性能数据 +- with_module:是否记录框架层python调用栈信息 +- with_stack:是否记录算子调用栈信息 +- analysis:是否自动解析数据 + +示例 +---- + +禁用采集 +~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: null # disable profile + +端到端采集 +~~~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: False + all_ranks: True + + +离散模式采集 +~~~~~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: True + all_ranks: False + ranks: [0, 1] + + +可视化 +------ + +采集后的数据存放在用户设置的save_path下,可通过 `MindStudio Insight `_ 工具进行可视化。 + +如果analysis参数设置为False,采集之后需要进行离线解析: + +.. code:: python + + import torch_npu + # profiler_path请设置为"localhost.localdomain___ascend_pt"目录的上一级目录 + torch_npu.profiler.profiler.analyse(profiler_path=profiler_path) \ No newline at end of file diff --git a/docs/ascend_tutorial/ascend_profiling_en.rst b/docs/ascend_tutorial/ascend_profiling_en.rst new file mode 100644 index 000000000..3ab067ae2 --- /dev/null +++ b/docs/ascend_tutorial/ascend_profiling_en.rst @@ -0,0 +1,109 @@ +Data collection based on FSDP (Fully Sharded Data Parallel) backend on Ascend devices(NPU) +========================================================================================== + +Last updated: 07/14/2025. + +This is a tutorial for data collection using the GRPO or DAPO algorithm +based on FSDP on Ascend devices. + +Configuration +------------- + +Reuse the configuration items in +verl/trainer/config/ppo_trainer.yaml to control the collection mode +and steps, you can also manage the collection behaviors such as +collection level via verl/trainer/config/npu_profile/npu_profile.yaml. + +Global collection control +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use parameters in ppo_trainer.yaml to control the collection mode +and steps. + +- trainer.profile_steps: This parameter can be set as a list that has + collection steps, such as [2, 4], which means it will collect steps 2 + and 4. If set to null, no collection occurs. +- actor_rollout_ref.profiler: Control the ranks and mode of profiling + + - all_ranks: Collects data from all ranks when set to true. + - ranks: This parameter specifies which ranks to collect (e.g., [0, + 1]) when all_ranks is False. + - discrete: Controls the collection mode. If False, end-to-end data + is collected; if True, data is collected in discrete phases during + training. + +Use parameters in npu_profile.yaml to control collection behavior: + +- save_path: Storage path for collected data. +- level: Collection level—options are level_none, level0, level1, and + level2 + + - level_none: Disables all level-based data collection (turns off + profiler_level). + - level0: Collect high-level application data, underlying NPU data, + and operator execution details on NPU. + - level1: Extends level0 by adding CANN-layer AscendCL data and AI + Core performance metrics on NPU. + - level2: Extends level1 by adding CANN-layer Runtime data and AI + CPU metrics. + +- record_shapes: Whether to record tensor shapes. +- with_memory: Whether to enable memory analysis. +- with_npu: Whether to collect device-side performance data. +- with_cpu: Whether to collect host-side performance data. +- with_module: Whether to record framework-layer Python call stack + information. +- with_stack: Whether to record operator call stack information. +- analysis: Enables automatic data parsing. + +Examples +-------- + +Disabling collection +~~~~~~~~~~~~~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: null # disable profile + +End-to-End collection +~~~~~~~~~~~~~~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: False + all_ranks: True + + +Discrete Mode Collection +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: yaml + + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: True + all_ranks: False + ranks: [0, 1] + + +Visualization +------------- + +Collected data is stored in the user-defined save_path and can be +visualized by using the `MindStudio Insight `_ tool. + +If the analysis parameter is set to False, offline parsing is required after data collection: + +.. code:: python + + import torch_npu + # Set profiler_path to the parent directory of the "localhost.localdomain___ascend_pt" folder + torch_npu.profiler.profiler.analyse(profiler_path=profiler_path) \ No newline at end of file diff --git a/docs/ascend_tutorial/ascend_quick_start.rst b/docs/ascend_tutorial/ascend_quick_start.rst index 0cd349bbf..589964328 100644 --- a/docs/ascend_tutorial/ascend_quick_start.rst +++ b/docs/ascend_tutorial/ascend_quick_start.rst @@ -1,6 +1,7 @@ verl x Ascend =================================== +Last updated: 06/17/2025. 我们在 verl 上增加对华为昇腾设备的支持。 @@ -9,7 +10,7 @@ verl x Ascend Atlas 200T A2 Box16 -Atlas 800T A2 +Atlas 900 A2 PODc 安装 @@ -46,13 +47,13 @@ vllm & vllm-ascend # for Atlas 200T A2 Box16 VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ - # for Atlas 800T A2 + # for Atlas 900 A2 PODc VLLM_TARGET_DEVICE=empty pip install -e . .. code-block:: bash # vllm-ascend - git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + git clone -b v0.7.3.post1 --depth 1 https://github.com/vllm-project/vllm-ascend.git cd vllm-ascend export COMPILE_CUSTOM_KERNELS=1 python setup.py install @@ -73,16 +74,24 @@ vllm & vllm-ascend +--------------+---------------+ | software | description | +--------------+---------------+ -| transformers | >= v4.52.0 | +| transformers | v4.52.4 | +--------------+---------------+ | flash_attn | not supported | +--------------+---------------+ | liger-kernel | not supported | +--------------+---------------+ +| tensordict | 0.8.3 (ARM) | ++--------------+---------------+ 1. 支持通过 transformers 使能 --flash_attention_2, transformers 需大于等于 4.52.0版本。 2. 不支持通过 flash_attn 使能 flash attention 加速。 3. 不支持 liger-kernel 使能。 +4. 针对 ARM 服务器,tensordict 要求 0.8.3,可在依赖安装完成后再手动安装 tensordict。 +5. 针对 x86 服务器,需要安装 cpu 版本的 torchvision。 + +.. code-block:: bash + + pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu 快速开始 @@ -134,7 +143,7 @@ vllm & vllm-ascend actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=8 \ @@ -148,20 +157,23 @@ vllm & vllm-ascend 支持现状 ----------------------------------- -+-----------+----------------------+-------------+-------------------+----------------------+ -| algorithm | model | rewards mae | throughput ratio | hardware | -+-----------+----------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-7B-instruct | 0.38% | 0.588 | Atlas 200T A2 Box16 | -+-----------+----------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-32B-instruct | 0.30% | 0.685 | Atlas 200T A2 Box16 | -+-----------+----------------------+-------------+-------------------+----------------------+ - -目前支持 Qwen2.5 的 GRPO 训练,Qwen2.5-VL GRPO 训练在 vllm-ascend 的修复后支持,涉及到的issue为: - -1. `issues#809 `_ - -2. `issues#825 `_ - ++-----------+-------------------------+-------------+-------------------+----------------------+ +| algorithm | model | rewards mae | throughput ratio | hardware | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-7B-instruct | 0.38% | 0.588 | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-32B-instruct | 0.30% | 0.685 | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-VL-3B-instruct | 3.14% | 0.470 | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-VL-7B-instruct | 3.30% | 0.380 | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| GRPO | Qwen2.5-VL-32B-instruct | 0.79% | 0.568 | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| DAPO | Qwen2.5-7B-instruct | 3.83% | pending | Atlas 200T A2 Box16 | ++-----------+-------------------------+-------------+-------------------+----------------------+ +| SFT-PEFT | Qwen2.5-0.5B-instruct | 0.06% | 0.305 | Atlas 900 A2 PODc | ++-----------+-------------------------+-------------+-------------------+----------------------+ 精度对比说明 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -189,4 +201,4 @@ Ascend npu 和 A100 分别取日志中前4个 step 的 "perf/throughput" 做平 声明 ----------------------------------- -verl中提供的ascend支持代码皆为参考样例,商业使用请通过官方正式途径沟通,谢谢。 \ No newline at end of file +verl中提供的ascend支持代码皆为参考样例,商业使用请通过官方正式途径沟通,谢谢。 diff --git a/docs/conf.py b/docs/conf.py index d2604dd4b..d405288ff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,6 +49,7 @@ "sphinx.ext.autosummary", "sphinx.ext.autosectionlabel", "sphinx.ext.napoleon", + "sphinx.ext.viewcode", ] # Use Google style docstrings instead of NumPy docstrings. napoleon_google_docstring = True @@ -57,8 +58,8 @@ # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', + ".rst": "restructuredtext", + ".md": "markdown", } # Add any paths that contain templates here, relative to this directory. @@ -93,3 +94,7 @@ html_js_files = [ "js/runllm-widget.js", ] + +exclude_patterns += ["README.md", "README_vllm0.7.md"] + +suppress_warnings = ["ref.duplicate", "ref.myst"] diff --git a/docs/examples/config.rst b/docs/examples/config.rst index e9cb89bb1..0f05c181b 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -3,6 +3,8 @@ Config Explanation =================== +Last updated: 06/18/2025. + ppo_trainer.yaml for RL FSDP Backend ------------------------------------- @@ -65,7 +67,9 @@ Data - ``data.truncation``: Truncate the input_ids or prompt length if they exceed max_prompt_length. Default is 'error', not allow exceed the max_prompt_length. The users should increase the max_prompt_length if - throwing the error. You can also set ``left`` and ``right``. + throwing the error. You can also set ``left``, ``right`` and ``middle``. + When ``middle`` is selected, the logic splits the allowed max length roughly in half + and keeps the head and tail of the sequence, effectively discarding the middle section. - ``data.image_key``: The field in the multi-modal dataset where the image is located. Default is 'images'. - ``data.trust_remote_code``: If the remote tokenizer has python file, we can use this field to allow @@ -137,7 +141,11 @@ Actor/Rollout/Reference Policy optimizer_offload: False fsdp_size: -1 checkpoint: - contents: ['model', 'optimizer', 'extra'] + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} ref: fsdp_config: param_offload: False @@ -175,6 +183,7 @@ Actor/Rollout/Reference Policy engine_kwargs: # inference engine parameters vllm: swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + disable_mm_preprocessor_cache: False # disable preprocessor cache for multimodel models sglang: attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla @@ -187,6 +196,11 @@ Actor/Rollout/Reference Policy n: 1 do_sample: False # default eager for validation + agent: + custom_async_server: # Use custom async server implementation for rollout + path: null + name: null + **Common config for actor, rollout and reference model** - ``actor_rollout_ref.hybrid_engine``: Whether it's a hybrid engine, @@ -205,6 +219,17 @@ Actor/Rollout/Reference Policy activation offloading for the actor - ``actor_rollout_ref.model.trust_remote_code``: Whether to enable loading a remote code model +- ``actor_rollout_ref.model.use_fused_kernels``: Whether to use fused + kernels in the model. If set to True, the following parameters will be + used. + - ``actor_rollout_ref.model.fused_kernel_options.impl_backend``: The + implementation backend for fused kernels. Options: "triton" or + "torch". Default is "torch". + While in megatron, we only support "triton" as the + implementation backend, so there is no need for this option. +- ``actor_rollout_ref.model.use_remove_padding``: Whether to use remove + padding in the model. If set to True, the model will remove padding + tokens in the input_ids and response_ids. This helps a lot in improving model running efficiency. **Actor model** @@ -267,9 +292,11 @@ Actor/Rollout/Reference Policy - ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor - - ``contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint. + - ``save_contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint. The extra information includes Rng states currently, FSDP supported lr_scheduler, and Megatron opt_param_scheduler will coming soon. - We do not store hf_model in checkpoint by default, but we provide a tool in `scripts/model_merge.py` to convert checkpoint format to hf format. + We do not store hf_model in checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert checkpoint format to hf format. + + - ``load_contents``: The contents to load in the checkpoint, you can specify different checkpoint loading contents. By default, it is the same with ``save_checkpoint``. **Reference Model** @@ -300,9 +327,6 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us - ``actor_rollout_ref.rollout.gpu_memory_utilization``: - - For vLLM v0.5.4 and v0.6.3: The proportion of the **remaining** GPU memory - allocated for kv cache after other models have initialized when using - vLLM. - For vLLM v0.7.0 and later: The fraction of **total** GPU memory to be used for the vLLM instance. - For SGLang: Corresponding to ``mem_fraction_static``, the fraction of the free GPU memory used for **static** memory like model weights and KV cache. @@ -331,6 +355,7 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us - ``actor_rollout_ref.rollout.engine_kwargs.vllm``: extra vllm engine args - ``swap_space``: swap space in GB used by the inference engine. Positive integer, e.g., ``32`` means 32 GB. ``null``: means not setting and using the engine default value (usually, e.g., 4 GB for vLLM) + - ``disable_mm_preprocessor_cache``: Whether to disable preprocessor cache for multimodel models. - ``actor_rollout_ref.rollout.engine_kwargs.sglang``: extra sglang engine args @@ -345,9 +370,9 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us token and continue generating tokens after the EOS token is generated. - ``actor_rollout_ref.rollout.free_cache_engine``: Offload the KVCache - after rollout generation stage. Default is True. When set to True, we - need to disable the usage of CUDAGraph (set ``enforce_eager`` to - True.) + after rollout generation stage. Default is True. When set to True, + for vllm v0.5.4 and v0.6.3, we need to disable the usage of CUDAGraph + (set ``enforce_eager`` to True.) - ``actor_rollout_ref.rollout.enforce_eager``: Whether to use CUDAGraph in vLLM generation. Default set to True to disable CUDAGraph. @@ -374,6 +399,42 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us .. note:: **NOTED**: In this config field, users only need to select from ``dummy_megatron``, ``dummy_dtensor``, ``dummy_hf`` for rollout initialization and our hybrid engine will select the corresponding weight loader (i.e., ``megatron``, ``dtensor``, ``hf``) during actor/rollout weight synchronization. + +Megatron Optimizer and Optimizer Parameter Scheduler +____________________________________________________ + +.. code:: yaml + + optim: + optimizer: adam + lr: 1e-6 + clip_grad: 1.0 + total_training_steps: -1 # must be override by program + lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 + lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + lr_decay_steps: null + lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root + min_lr: 0.0 # minimum learning rate, default to 0.0 + weight_decay: 0.01 + weight_decay_incr_style: constant # select from constant/linear/cosine + lr_wsd_decay_style: exponential # select from constant/exponential/cosine + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler + + +Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer. + +- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup. +- Megatron optimizer also support weight decay decay mechanism +- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training. + +For learning rate decay, original Megatron pretrain default option of ``lr_decay_style`` is ``linear``, +meaning that the learning rate will be linearly decayed from the initial learning rate to ``min_lr`` within the +``lr_decay_steps``. However, in verl, to align with FSDP's default behavior, we set the default +``lr_decay_style`` to ``constant``, meaning that the learning rate will be kept constant after the warmup stage. + + Critic Model ~~~~~~~~~~~~ @@ -448,7 +509,7 @@ Algorithm horizon: 10000 target_kl: 0.1 -- ``gemma``: discount factor +- ``gamma``: discount factor - ``lam``: Trade-off between bias and variance in the GAE estimator - ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo`` - ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False. @@ -477,7 +538,7 @@ Trainer val_before_train: True test_freq: 2 critic_warmup: 0 - default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} # hdfs checkpoint path + default_hdfs_dir: null # hdfs checkpoint path default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path resume_mode: auto # or disable or resume_path if resume_from_path is set resume_from_path: null @@ -501,7 +562,7 @@ Trainer - ``trainer.resume_mode``: The mode of resuming training. Support ``disable``, ``auto`` and ``resume_path``. If set to ``auto`` as default, the program will automatically resume from the latest checkpoint in the - default_hdfs_dir. If set to ``resume_path``, the program will resume + ``default_local_dir``. If set to ``resume_path``, the program will resume from the path specified in ``resume_from_path``. - ``trainer.resume_from_path``: The path to resume training from. Only effective when ``resume_mode`` is set to ``resume_path``. diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst index 17eb76aca..02d1a526c 100644 --- a/docs/examples/gsm8k_example.rst +++ b/docs/examples/gsm8k_example.rst @@ -1,6 +1,8 @@ GSM8K Example ============= +Last updated: 03/25/2025. + Introduction ------------ @@ -90,11 +92,10 @@ We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft data.response_key=answer \ data.micro_batch_size_per_gpu=8 \ model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ - trainer.default_hdfs_dir=hdfs://user/verl/experiments/gsm8k/deepseek-coder-6.7b-instruct/ \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','wandb'] + trainer.logger='["console","wandb"]' If you use AMD GPUs (ROCm kernel), you need to add the following environment variables into the run script: @@ -168,7 +169,7 @@ The script of run_deepseek7b_llm.sh critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/docs/examples/multi_modal_example.rst b/docs/examples/multi_modal_example.rst index 6e7348846..844005b66 100644 --- a/docs/examples/multi_modal_example.rst +++ b/docs/examples/multi_modal_example.rst @@ -1,6 +1,8 @@ Multi-Modal Example Architecture ================================= +Last updated: 04/28/2025. + Introduction ------------ diff --git a/docs/examples/ppo_code_architecture.rst b/docs/examples/ppo_code_architecture.rst index 1ad7c9658..94d62413a 100644 --- a/docs/examples/ppo_code_architecture.rst +++ b/docs/examples/ppo_code_architecture.rst @@ -1,6 +1,8 @@ PPO Example Architecture ======================== +Last updated: 02/17/2025. + Let's start with the Proximal Policy Optimization algorithm, which is most widely used algorithm in LLM post-training. @@ -46,8 +48,8 @@ Define worker classes .. code:: python - if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: # for FSDP backend + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.single_controller.ray import RayWorkerGroup ray_worker_group_cls = RayWorkerGroup diff --git a/docs/examples/sandbox_fusion_example.rst b/docs/examples/sandbox_fusion_example.rst index b1040f28d..f3359efda 100644 --- a/docs/examples/sandbox_fusion_example.rst +++ b/docs/examples/sandbox_fusion_example.rst @@ -1,6 +1,8 @@ Sandbox Fusion Example ============================ +Last updated: 06/27/2025. + Introduction ------------ @@ -35,6 +37,7 @@ To integrate Sandbox Fusion into your training script, configure the following p - ``reward_model.sandbox_fusion.url=''``: Enable Sandbox Fusion by specifying the API endpoint (must end with ``/run_code``). - ``reward_model.sandbox_fusion.max_concurrent=256``: Set the maximum number of concurrent API requests to the Sandbox Fusion service. +- ``reward_model.sandbox_fusion.memory_limit_mb=1024``: Set the memory limit (in MB) for each sandbox instance. Defaults to 1024MB if not specified. **Additional Optimization** diff --git a/docs/faq/faq.rst b/docs/faq/faq.rst index 4bd626491..328ad6eb7 100644 --- a/docs/faq/faq.rst +++ b/docs/faq/faq.rst @@ -1,6 +1,8 @@ Frequently Asked Questions ==================================== +Last updated: 06/25/2025. + Ray related ------------ @@ -100,19 +102,12 @@ Solution 2nd: Illegal memory access --------------------------------- -If you encounter the error message like ``CUDA error: an illegal memory access was encountered`` during rollout, most likely it is due to a known issue from vllm(<=0.6.3). -Please set the following environment variable. The env var must be set before the ``ray start`` command if any. - -.. code:: bash - - export VLLM_ATTENTION_BACKEND=XFORMERS - -If in doubt, print this env var in each rank to make sure it is properly set. +If you encounter the error message like ``CUDA error: an illegal memory access was encountered`` during rollout, please check the vLLM documentation for troubleshooting steps specific to your vLLM version. Checkpoints ------------------------ -If you want to convert the model checkpoint into huggingface safetensor format, please refer to ``scripts/model_merger.py``. +If you want to convert the model checkpoint into huggingface safetensor format, please refer to ``verl/model_merger``. Triton ``compile_module_from_src`` error @@ -155,6 +150,23 @@ https://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA .. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d +How to generate ray timeline to analyse performance of a training job? +------------------------------------------------------------------------------------------ + +To generate the ray timeline file, you can set the config term ``ray_init.timeline_file`` to a json file path. +For example: + +.. code:: bash + + ray_init.timeline_file=/tmp/ray_timeline.json + +The file will be generated in the specified path at the end of a training job. +You can use tools like chrome://tracing or the Perfetto UI and view the ray timeline file. + +This figure shows the ray timeline file generated by from a training job on 1 node with 4 GPUs + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray_timeline.png?raw=true + How to set proxy only for wandb? ------------------------------------------------------------------------------------------ @@ -163,4 +175,4 @@ Comparing to using global https_proxy env variable, this approach won't mess up .. code:: bash - +trainer.wandb_proxy=http:// + +trainer.wandb_proxy=http:// \ No newline at end of file diff --git a/docs/hybrid_flow.rst b/docs/hybrid_flow.rst index 14694846c..3aa5a4a97 100644 --- a/docs/hybrid_flow.rst +++ b/docs/hybrid_flow.rst @@ -2,6 +2,8 @@ HybridFlow Programming Guide ========================================================= +Last updated: 06/02/2025. + .. _vermouth: https://github.com/vermouth1992 Author: `Chi Zhang `_ @@ -115,7 +117,7 @@ Each worker inside the WorkerGroup runs on a GPU. The worker group serves as a p For example, in PPO, we define 3 worker groups: -- ActorRolloutRef: manages actor, rollout and reference policy. ActorRolloutRefWorker can be instantiated as a single actor, a single rollout, a single reference policy, a combined actor/rollout or a combined actor/rollout/ref. This design is aimed for the maximum code reuse in various scenarios. The reason for colocating actor and rollout is for fast weight transfer using nccl. The reason for coloating actor and reference is to implement an efficient lora PPO as the reference policy is simply the base model of PPO in lora. +- ActorRolloutRef: manages actor, rollout and reference policy. ActorRolloutRefWorker can be instantiated as a single actor, a single rollout, a single reference policy, a combined actor/rollout or a combined actor/rollout/ref. This design is aimed for the maximum code reuse in various scenarios. The reason for colocating actor and rollout is for fast weight transfer using nccl. The reason for coloating actor and reference is to implement an efficient lora PPO as the reference policy is simply the base model of PPO in lora. The colocation is done via ``verl.single_controller.ray.base.create_colocated_worker_cls``, where it creates a single ray remote class exposing all class methods from these roles. - Critic: manages the critic model - Reward: manages the reward model @@ -252,12 +254,7 @@ Important code files in the repository are organized as below: weight_loader_registery.py # registry of weight loaders for loading hf ckpt into Megatron third_party vllm # adaptor for vllm's usage in RL - vllm_v_0_6_3 # vllm v0.6.3 adaptor - llm.py # entrypoints for generate, sync_model_weight, offload_model_weights - parallel_state.py # vllm related device mesh and process groups - dtensor_weight_loaders.py # weight loader for huggingface models with FSDP - megatron_weight_loaders.py # weight loader for Megatron models - vllm_spmd # vllm >= v0.7 adaptor (coming soon) + vllm_spmd # vllm >= v0.7 adaptor examples # example scripts tests # integration and unit tests .github # the configuration of continuous integration tests diff --git a/docs/index.rst b/docs/index.rst index 201be82b5..980066a7f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,8 @@ verl is fast with: start/quickstart start/multinode start/ray_debug_tutorial + start/more_resources + start/agentic_rl .. toctree:: :maxdepth: 2 @@ -70,10 +72,12 @@ verl is fast with: algo/dapo.md algo/spin.md algo/sppo.md + algo/entropy.md algo/opo.md algo/baseline.md + algo/gpg.md -.. toctree:: +.. toctree:: :maxdepth: 1 :caption: PPO Trainer and Workers @@ -85,10 +89,12 @@ verl is fast with: .. toctree:: :maxdepth: 1 :caption: Performance Tuning Guide - + + perf/dpsk.md perf/perf_tuning README_vllm0.8.md perf/device_tuning + perf/nsight_profiling.md .. toctree:: :maxdepth: 1 @@ -105,9 +111,21 @@ verl is fast with: advance/rope advance/ppo_lora.rst sglang_multiturn/multiturn.rst + sglang_multiturn/interaction_system.rst advance/placement advance/dpo_extension examples/sandbox_fusion_example + advance/rollout_trace.rst + +.. toctree:: + :maxdepth: 1 + :caption: Hardware Support + + amd_tutorial/amd_build_dockerfile_page.rst + amd_tutorial/amd_vllm_page.rst + ascend_tutorial/ascend_quick_start.rst + ascend_tutorial/ascend_profiling.rst + ascend_tutorial/ascend_profiling_en.rst .. toctree:: :maxdepth: 1 @@ -125,6 +143,12 @@ verl is fast with: faq/faq +.. toctree:: + :maxdepth: 1 + :caption: Development Notes + + sglang_multiturn/sandbox_fusion.rst + Contribution ------------- diff --git a/docs/perf/device_tuning.rst b/docs/perf/device_tuning.rst index 3b5933805..567683b3b 100644 --- a/docs/perf/device_tuning.rst +++ b/docs/perf/device_tuning.rst @@ -1,5 +1,7 @@ -Resource Needed for verl RL -============================== +Hardware Resource Needed for RL +=============================== + +Last updated: 06/25/2025. Since RL requires more resources compared to regular training, determining how much resources are needed to successfully run it before training @@ -8,7 +10,7 @@ resource selection when dealing with different models and tasks, this section is mainly dedicated to introducing the environmental requirements based on experiments we have conducted. -However, due to limited manpower and equipment resources, we also hope for more +However, due to limited staff and equipment resources, we also hope for more contributions from the open-source community. When submitting a PR, it is necessary to provide a script to be added to the example/tuning scripts. @@ -27,6 +29,84 @@ a PR and include a screenshot from Wandb or other verifiable evidence. ---------------------------------------- +0.5B +~~~ + +.. list-table:: + :widths: auto + :header-rows: 1 + + * - Tag + - Model + - Task + - Resource + - MaxBatch + - Train + - Infer + - Link + - Contributor + * - MIN + - Qwen2.5-0.5B + - GRPO-LoRA + - 1*H100 + - 116 + - fsdp + - vllm0.8.3 + - `qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ + +1.5B +~~~ + +.. list-table:: + :widths: auto + :header-rows: 1 + + * - Tag + - Model + - Task + - Resource + - MaxBatch + - Train + - Infer + - Link + - Contributor + * - MIN + - Qwen2.5-1.5B + - GRPO-LoRA + - 1*H100 + - 128 + - fsdp + - vllm0.8.3 + - `qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ + +3B +~~~ + +.. list-table:: + :widths: auto + :header-rows: 1 + + * - Tag + - Model + - Task + - Resource + - MaxBatch + - Train + - Infer + - Link + - Contributor + * - MIN + - Qwen2.5-3B + - GRPO-LoRA + - 1*H100 + - 62 + - fsdp + - vllm0.8.3 + - `qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ + 7B ~~~ @@ -38,6 +118,7 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Model - Task - Resource + - MaxBatch - Train - Infer - Link @@ -46,11 +127,20 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Qwen2-7B - GRPO - 2*H800 + - \ - fsdp - vllm0.8.2 - `qwen2-7b_grpo_2_h800_fsdp_vllm `_ - `Xiangyongan `_ - + * - MIN + - Qwen2.5-7B + - GRPO-LoRA + - 1*H100 + - 16 + - fsdp + - vllm0.8.3 + - `qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ 14B ~~~ @@ -63,6 +153,7 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Model - Task - Resource + - MaxBatch - Train - Infer - Link @@ -71,11 +162,20 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Qwen2-14B - GRPO - 4*H800 + - \ - fsdp - vllm0.8.2 - `qwen2-14b_grpo_4_h800_fsdp_vllm `_ - `Xiangyongan `_ - + * - MIN + - Qwen2.5-14B + - GRPO-LoRA + - 2*H100 + - 116 + - fsdp + - vllm0.8.3 + - `qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ 32B ~~~ @@ -88,6 +188,7 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Model - Task - Resource + - MaxBatch - Train - Infer - Link @@ -96,10 +197,20 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Qwen2-32B - GRPO - 8*H20 + - \ - megatron - vllm0.8.2 - `qwen2-32b_grpo_8_h20_megatron_vllm `_ - `Xiangyongan `_ + * - MIN + - Qwen2.5-32B + - GRPO-LoRA + - 4*H100 + - 180 + - fsdp + - vllm0.8.3 + - `qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ 70B ~~~ @@ -112,6 +223,7 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Model - Task - Resource + - MaxBatch - Train - Infer - Link @@ -120,6 +232,7 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Qwen2-70B - GRPO - 32*H20 + - \ - fsdp - vllm0.8.2 - `qwen2-70b_grpo_32_h20_fsdp_vllm `_ @@ -128,10 +241,20 @@ a PR and include a screenshot from Wandb or other verifiable evidence. - Qwen2-70B - GRPO - 32*H800 + - \ - fsdp - vllm0.8.3 - `qwen2-70b_grpo_32_h800_fsdp_vllm `_ - `Xiangyongan `_ + * - MIN + - Qwen2.5-72B + - GRPO-LoRA + - 8*H100 + - 176 + - fsdp + - vllm0.8.3 + - `qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh `_ + - `SimonHuang `_ 405B ~~~~ @@ -139,12 +262,11 @@ a PR and include a screenshot from Wandb or other verifiable evidence. .. table:: :widths: auto - ====== ====== ====== ======== ====== ====== ====== - tag model task resource train infer link - ====== ====== ====== ======== ====== ====== ====== - \ \ \ \ \ \ - ====== ====== ====== ======== ====== ====== ====== - + ====== ====== ====== ======== ======== ====== ====== ====== + tag model task resource MaxBatch train infer link + ====== ====== ====== ======== ======== ====== ====== ====== + \ \ \ \ \ \ \ + ====== ====== ====== ======== ======== ====== ====== ====== 671B ~~~~ @@ -152,8 +274,8 @@ a PR and include a screenshot from Wandb or other verifiable evidence. .. table:: :widths: auto - ====== ====== ====== ======== ====== ====== ====== - tag model task resource train infer link - ====== ====== ====== ======== ====== ====== ====== - \ \ \ \ \ \ - ====== ====== ====== ======== ====== ====== ====== + ====== ====== ====== ======== ======== ====== ====== ====== + tag model task resource MaxBatch train infer link + ====== ====== ====== ======== ======== ====== ====== ====== + \ \ \ \ \ \ \ + ====== ====== ====== ======== ======== ====== ====== ====== diff --git a/docs/perf/dpsk.md b/docs/perf/dpsk.md new file mode 100644 index 000000000..0a3b42a11 --- /dev/null +++ b/docs/perf/dpsk.md @@ -0,0 +1,51 @@ +# Training DeepSeek 671b + +Last updated: 06/13/2025. + +verl integrates Megatron to support large MoE models such as `Qwen3-235B-A22B` and `deepseek-ai/DeepSeek-V3`. This is an ongoing community effort. + +In the journey the community added the following features and optimizations that enable verl with larger models: +- per tensor weight resharding between rollout and training +- context parallelism and expert parallelism enabled via megatron +- dynamic batch size (sequence balance) for megatron +- reduced ray-related serialization overhead +- optimizer offloading, recomputation, and efficient kernels +- various debugging metrics and utils + +and the megatron backend now has a wider list of models supported: +- DeepSeek-V3 +- Moonlight +- Qwen3 +- Qwen2.5-VL (to be merged soon) +- Qwen2 +- Mixtral + +## Getting Started + +### DeepSeek 671b + +The recommended image with pre-built megatron dependency is `whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.1-te2.3-deepseekv3`, built with the Dockerfile in [docker/Dockerfile.vllm.sglang.megatron.deepseek](https://github.com/volcengine/verl/blob/main/docker/Dockerfile.vllm.sglang.megatron.deepseek). + +For checkpoint loading, we rely on megatron dist-ckpt for resharding. A converted dist-ckpt for DeepSeek-V3 is available from [huggingface BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt](https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main). + +To run end-to-end training on the DAPO dataset, run [recipe/dapo/test_dapo_dspk_671b_megatron.sh](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_dspk_671b_megatron.sh). It runs on 512 H20(96GB) GPUs with the following setup: +- vllm rollout with TP=32, bfloat16 +- megatron training with attention DP, MoE EP=32, PP=16, bfloat16 + +MTP is disabled during RL training. + +### Qwen3 236b + +For Qwen3-236b, please refer to [examples/grpo_trainer/run_qwen3-236b_megatron.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-236b_megatron.sh), which runs on 128 H20(96GB) GPUs. + +## Upcoming Optimizations + +The community continue to optimize large MoE models further, ongoing efforts include: +- further optimizing memory consumption, and provide recommended/tuned configurations with various machine types +- optimizing long context RL training performance +- performance improvement with SGLang x Megatron + +We invite the community to try and improve verl together. Get connected with us on [slack](https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA)/[wechat](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG)/[Github issues](https://github.com/volcengine/verl/issues/708)! + +## Acknowledgement +@vermouth1992 @ISEEKYAN @ETOgaosion @yzlnew @ShareLer @BearBiscuit05 @ccclyu @ann-qin-lu @SwordFaith @zzong2006 @zhaochenyang20 @ocss884 @eric-haibin-lin diff --git a/docs/perf/nsight_profiling.md b/docs/perf/nsight_profiling.md new file mode 100644 index 000000000..ed083c38e --- /dev/null +++ b/docs/perf/nsight_profiling.md @@ -0,0 +1,107 @@ +# NVIDIA Nsight Systems profiling in verl + +Last updated: 06/20/2025. + +This guide explains how to use NVIDIA Nsight Systems for profiling verl training runs. + +## Configuration + +Profiling in verl can be configured through several parameters in the trainer configuration file (ppo_trainer.yaml or other files like dapo_trainer.yaml): + +### Prerequisites + +Nsight Systems version is important, please reference `docker/Dockerfile.vllm.sglang.megatron` for the version we used. + +### Global profiling control + +verl has one single controller process and multiple worker processes. Both controller and worker processes can be profiled. Since the controller process can be executed in any nodes in the cluster, there is a message printed in the logging to indicate the controller process node hostname and process id. + +In `trainer`, three new config entries control the profiler behaviors: + +* **`trainer.profile_steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling. + + +* **`controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details. + +* **`worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: "cudaProfilerApi"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`. + +### Worker process profiling + +Verl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields: + +* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_..nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID. + +* **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. In this case, each role's action occupies one ``. + +* **`actor_rollout_ref`**. This Worker can be configured to contain at most 3 roles and executes together. So `actor_rollout_ref` has a `profiler` config and all the inside roles inherit it. + +* **Verl collocate mode**. Verl can combine two Worker sub classes to one Worker Actor. In this case, the user should take care that the combined Workers have consistent `discrete`. The Nsight Systems profiler uses a `torch.cuda.profiler.start()` and `stop()` pair to dump a `` database anyway. + +### where to find the profiling data + +By default the `*.nsys-rep` files are saved in the directory `/tmp/ray/session_latest/logs/nsight/` at each node. According to the Ray manual, this default directory is not changeable. ["however, Ray preserves the `--output` option of the default config"](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html). + +Some users may think it is not convenient, but it is understandable that Ray may start hundreds of processes and it would be a big network file system pressure if we save the files in one central place. + +## Usage Example + +To enable profiling for specific components and steps, modify your ppo_trainer.yaml like this: + +### Disable profiler +```yaml + trainer: + profile_steps: null # disable profile +``` + +### Enable profiler and one database for one training step +```yaml + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: False + all_ranks: False + ranks: [0, 1] + critic: + profiler: + discrete: False + all_ranks: False + ranks: [0, 1] + reward_model: + profiler: + discrete: False + all_ranks: False + ranks: [0, 1] +``` + +### Enable profiler and multiple databases for one training step +```yaml + trainer: + profile_steps: [1, 2, 5] + actor_rollout_ref: + profiler: + discrete: True + all_ranks: False + ranks: [0, 1] + critic: + profiler: + discrete: True + all_ranks: False + ranks: [0, 1] + reward_model: + profiler: + discrete: True + all_ranks: False + ranks: [0, 1] +``` + +## Profiling Output + +When profiling is enabled, verl will generate Nsight Systems profiles for the specified components and steps. The profiles will include: + +- CUDA kernel execution +- Memory operations +- CPU-GPU synchronization +- NVTX markers for key operations + +Nsight Systems supports multi-report view, to open multiple databases together. In this mode, different processes and steps can be aligned in one time line for better analysis. diff --git a/docs/perf/perf_tuning.rst b/docs/perf/perf_tuning.rst index bab3dc29d..58df6ce13 100644 --- a/docs/perf/perf_tuning.rst +++ b/docs/perf/perf_tuning.rst @@ -1,7 +1,9 @@ Performance Tuning Guide ============================== -Author: `Guangming Sheng `_ +Last updated: 06/23/2025. + +Author: `Guangming Sheng `_, `Jiali Zheng `_ In this section, we will discuss how to tune the performance of all the stages in verl, including: @@ -17,6 +19,10 @@ In this section, we will discuss how to tune the performance of all the stages i 6. LigerKernel for SFT performance optimization +7. Forward prefetch in FSDP training backend + +8. Memory optimization for entropy calculation from logits + Rollout Generation Tuning -------------------------- @@ -26,7 +32,6 @@ Below are key factors for tuning vLLM-based rollout. Before tuning, we recommend - Increase ``gpu_memory_utilization``. - - For vLLM v0.5.4 and v0.6.3, the vLLM pre-allocates GPU KVCache by using gpu_memory_utilization of the **remaining** memory. - For vLLM v0.7.0 and later, the vLLM instance will only use gpu_memory_utilization of the **total** memory. - For SGLang, it's the fraction of the free GPU memory used for **static** memory like model weights and KV cache. However, the remaining (1-gpu_memory_utilization) will also be used during inference. @@ -49,7 +54,7 @@ Below are key factors for tuning vLLM-based rollout. Before tuning, we recommend More tuning details such as dealing with Preemption and Chunked-prefill can be found in `vLLM official tuning guide `_ -The performance of vllm can be further increased if upgrading from v0.6.3 to v0.7. See https://github.com/volcengine/verl/blob/main/docs/README_vllm0.7.md for details on how to upgrade. +For optimal performance, we recommend using vLLM v0.8.3 or later. See https://github.com/volcengine/verl/blob/main/docs/README_vllm0.8.md for details. Enable remove padding (sequence packing) ----------------------------------------- @@ -172,3 +177,23 @@ LigerKernel is a high-performance kernel for Supervised Fine-Tuning (SFT) that c 3. LigerKernel is particularly useful for improving training performance in SFT scenarios. +Forward prefetch in FSDP training backend +---------------------- + +During the training phase, users can enable forward prefetching in FSDP by setting ``fsdp_config.forward_prefetch=True``. For example, ``actor_rollout_ref.actor.fsdp_config.forward_prefetch=True``. This configuration prefetches the next forward-pass all-gather operation before completing the current forward computation, overlapping communication with computation and improving efficiency. For further details, refer to the `FSDP forward_pefetch `_ documentation. + +.. note:: + Backward prefetch is unsupported because the ``BACKWARD_POST`` policy may prefetch incorrectly in nested-module cases. For details, see the `FSDP documentation `_ + +Memory optimization for entropy calculation from logits +---------------------- + +The ``logits`` tensor (typically of shape ``[bsz*seq_len, voc]``) can consume significant memory. When using ``compute_entropy_from_logits``, memory usage reaches approximately ``[bsz*seq_len, voc] × (4 bytes (float32) + 2 bytes (autocast for softmax+logsumexp) + 1 byte (softmax output))``. + +To reduce this memory peak, enable chunked computation by setting: +``actor_rollout_ref.ref.entropy_from_logits_with_chunking = True`` +This processes the tensor in chunks of shape ``[chunk_size, voc]`` (e.g., 2048) rather than the full sequence length, exclusively during the model's forward pass. + +Additionally, during training, standard gradient checkpointing (``enable_gradient_checkpointing=True``) does not apply to entropy calculations. To reduce memory peaks in this context, set: +``actor_rollout_ref.actor.entropy_checkpointing = True`` +This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training. diff --git a/docs/preparation/prepare_data.rst b/docs/preparation/prepare_data.rst index de88d8d91..312352826 100644 --- a/docs/preparation/prepare_data.rst +++ b/docs/preparation/prepare_data.rst @@ -1,6 +1,8 @@ Prepare Data for Post-Training ======================================== +Last updated: 02/09/2025. + Before starting the post-training job, we need to prepare the data for the policy training. The data should be stored in the parquet format. diff --git a/docs/preparation/reward_function.rst b/docs/preparation/reward_function.rst index 426340783..286e2aff4 100644 --- a/docs/preparation/reward_function.rst +++ b/docs/preparation/reward_function.rst @@ -1,6 +1,8 @@ Implement Reward Function for Dataset ====================================== +Last updated: 06/02/2025. + For each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses. We already pre-implemented some reward functions in `reward_score directory `_. You can also use customized reward functions. diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index d1c4d014a..55ccdb8f7 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -10,4 +10,4 @@ sphinx-markdown-tables sphinx-rtd-theme # pin tokenizers version to avoid env_logger version req -tokenizers==0.19.1 +tokenizers==0.21 diff --git a/docs/sglang_multiturn/interaction_system.rst b/docs/sglang_multiturn/interaction_system.rst new file mode 100644 index 000000000..26b3db91e --- /dev/null +++ b/docs/sglang_multiturn/interaction_system.rst @@ -0,0 +1,419 @@ +Interaction System for Multi-turn RL Training +============================================= + +Last updated: 06/25/2025. + +Overview +-------- + +The verl interaction system enables dynamic, multi-turn conversational feedback during reinforcement learning training. This system allows models to engage in iterative problem-solving scenarios where interaction agents can provide corrective feedback, guidance, or evaluation based on the model's responses. + +**New in Multi-Interaction Support**: The system now supports multiple named interactions within a single training session, enabling sophisticated training scenarios where different samples can use different interaction strategies. This allows for curriculum learning, domain-specific feedback, and flexible agent switching at the sample level. + +Key features: + +- **Async-based Architecture**: Non-blocking interaction processing for distributed training +- **Instance Management**: Stateful session handling with unique instance IDs for concurrent interactions +- **SGLang Integration**: Seamless integration with SGLang rollout system for multi-turn conversations +- **Configuration-driven**: Dynamic agent loading via YAML configuration files +- **Multi-Interaction Support**: Registry system enabling multiple named interactions per rollout +- **Sample-Level Selection**: Each sample can specify which interaction to use via configuration +- **Reward Integration**: Turn-level scoring mechanism integrated with verl's reward system + +Architecture +------------ + +The interaction system follows a plugin-based architecture with clear separation of concerns: + +.. code-block:: + + Interaction Registry System + ↓ + BaseInteraction (Abstract Interface) + ↓ + Multiple Named Interactions (e.g., Gsm8kInteraction, CustomInteraction) + ↓ + SGLang Rollout Integration (interaction_map) + ↓ + Sample-Level Interaction Selection + ↓ + Async Request Lifecycle Management + +Core Components +~~~~~~~~~~~~~~~ + +**Interaction Registry System** + +The interaction registry system allows loading and managing multiple named interactions: + +.. code-block:: python + + from verl.interactions.utils.interaction_registry import initialize_interactions_from_config + + # Load multiple interactions from config + interaction_map = initialize_interactions_from_config("config.yaml") + + # Access specific interaction by name + gsm8k_interaction = interaction_map["gsm8k"] + custom_interaction = interaction_map["custom_solver"] + +**BaseInteraction Interface** + +All interaction agents must implement the ``BaseInteraction`` abstract class: + +.. code-block:: python + + from verl.interactions.base import BaseInteraction + from typing import Dict, Any, List, Tuple, Optional + + class BaseInteraction: + def __init__(self, config: Dict[str, Any]): + self.config = config + self.name: str = config.get("name", "interaction_agent") + + async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Initialize interaction session, return instance_id""" + + async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]: + """Generate response, return (should_terminate, response, score, metadata)""" + + async def calculate_score(self, instance_id: str, **kwargs) -> float: + """Calculate turn-level score for RL training""" + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + """Clean up resources""" + +**Request Lifecycle** + +The interaction system integrates with SGLang's async rollout via state management: + +1. ``PENDING`` → Initialize interaction via ``start_interaction()`` +2. ``GENERATING`` → Model generates response +3. ``INTERACTING`` → Process response via ``generate_response()`` +4. ``GENERATING`` → Continue if not terminated, otherwise ``COMPLETED`` + +Configuration +------------- + +**Basic Setup** + +Enable interaction in your rollout configuration: + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + multi_turn: + enable: true + interaction_config_path: "path/to/interaction_config.yaml" + max_user_turns: 10 + max_assistant_turns: 10 + +**Interaction Configuration File** + +Create an interaction configuration file (e.g., ``interaction_config.yaml``): + +**Single Interaction (Legacy Format)** + +.. code-block:: yaml + + interaction: + - name: "gsm8k" + class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + config: {} + +**Multiple Interactions (New Format)** + +.. code-block:: yaml + + interaction: + - name: "gsm8k" + class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + config: {} + - name: "custom_solver" + class_name: "custom.interactions.CustomInteraction" + config: + solver_type: "advanced" + timeout: 30 + - name: "code_verifier" + class_name: "verl.interactions.base.BaseInteraction" + config: + verification_mode: "strict" + +**Automatic Name Generation** + +If no ``name`` field is provided, the system will automatically generate one from the class name: + +.. code-block:: yaml + + interaction: + - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + config: {} + # Automatically generates name: "gsm8k" + +The system will dynamically load all specified interaction classes and make them available by name. + +Implementation Example: GSM8K +----------------------------- + +The GSM8K interaction demonstrates a complete implementation for math problem-solving scenarios: + +.. code-block:: python + + from verl.interactions.base import BaseInteraction + from verl.utils.reward_score import gsm8k + from uuid import uuid4 + + class Gsm8kInteraction(BaseInteraction): + def __init__(self, config: dict): + super().__init__(config) + self._instance_dict = {} + + async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs): + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def generate_response(self, instance_id, messages, **kwargs): + # Extract last user message content + content = "" + for item in reversed(messages): + if item.get("role") == "user": + content = item.get("content", "") + break + + # Ensure GSM8K format (#### prefix) + if content.startswith("#### "): + self._instance_dict[instance_id]["response"] = content + else: + self._instance_dict[instance_id]["response"] = "#### " + content + + reward = await self.calculate_score(instance_id) + if reward == 1.0: + return True, "Your response is correct!", 1.0, {} + else: + return False, "Your response is incorrect! You need to reflect on your answer and try again.", 0.0, {} + + async def calculate_score(self, instance_id, **kwargs): + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", format_score=0.0, score=1.0, + ) + + async def finalize_interaction(self, instance_id, **kwargs): + del self._instance_dict[instance_id] + +Training Integration +-------------------- + +**Training Script Configuration** + +Include interaction configuration in your training command: + +.. code-block:: bash + + python3 -m verl.trainer.main_ppo \\ + --config-path="$CONFIG_PATH" \\ + --config-name='gsm8k_multiturn_grpo_w_interaction' \\ + algorithm.adv_estimator=grpo \\ + data.train_batch_size=512 \\ + data.return_raw_chat=True \\ + actor_rollout_ref.rollout.name=sglang \\ + actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \\ + trainer.total_epochs=15 + +**Data Requirements** + +Ensure your dataset includes interaction parameters with the ``name`` field for interaction selection: + +.. code-block:: python + + # Dataset should include interaction_kwargs in non_tensor_batch + interaction_kwargs = [ + {"name": "gsm8k", "query": "What is 2+2?", "ground_truth": "4"}, + {"name": "custom_solver", "query": "Solve: x^2 + 5x + 6 = 0", "ground_truth": "x = -2, -3"}, + {"name": "gsm8k", "query": "What is 3+3?", "ground_truth": "6"}, + ] + +**Sample-Level Interaction Selection** + +Each sample can specify which interaction to use via the ``name`` field. This enables flexible training scenarios where different samples use different interaction strategies: + +.. code-block:: python + + # Example: Math problems use GSM8K interaction, code problems use code verifier + data_samples = [ + { + "prompt": "What is 15% of 200?", + "interaction_kwargs": { + "name": "gsm8k", + "query": "What is 15% of 200?", + "ground_truth": "30" + } + }, + { + "prompt": "Write a function to check if a number is prime", + "interaction_kwargs": { + "name": "code_verifier", + "code_type": "python", + "expected_behavior": "return True for prime numbers" + } + } + ] + +**Backward Compatibility** + +If no ``name`` field is provided in ``interaction_kwargs``, the system defaults to ``"gsm8k"`` for backward compatibility. + +Best Practices +-------------- + +**Resource Management** + +- Always implement proper cleanup in ``finalize_interaction()`` +- Use unique instance IDs to avoid conflicts in concurrent training +- Handle edge cases like empty messages or malformed content + +**Performance Optimization** + +- Keep interaction logic lightweight to avoid blocking training +- Use async/await properly to maintain non-blocking behavior +- Consider caching expensive computations within interaction instances + +**Testing** + +Comprehensive testing is essential for interaction systems: + +.. code-block:: python + + import pytest + from unittest.mock import patch + + @pytest.mark.asyncio + async def test_interaction_workflow(): + interaction = YourInteraction({}) + + # Test complete workflow + instance_id = await interaction.start_interaction(ground_truth="expected_answer") + + messages = [{"role": "user", "content": "user_response"}] + should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages) + + assert should_terminate in [True, False] + assert isinstance(reward, float) + + await interaction.finalize_interaction(instance_id) + +Advanced Usage +-------------- + +**Multi-Interaction Training Strategies** + +You can design sophisticated training scenarios using multiple interactions: + +.. code-block:: python + + # Example: Progressive difficulty with different interaction agents + class MathTrainingPipeline: + def create_interaction_config(self): + return { + "interaction": [ + { + "name": "basic_math", + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {"difficulty": "easy"} + }, + { + "name": "advanced_math", + "class_name": "custom.interactions.AdvancedMathInteraction", + "config": {"difficulty": "hard", "allow_hints": True} + }, + { + "name": "competition_math", + "class_name": "custom.interactions.CompetitionMathInteraction", + "config": {"time_limit": 300, "show_steps": False} + } + ] + } + + def create_curriculum_data(self, epoch): + if epoch < 5: + return [{"name": "basic_math", ...} for _ in samples] + elif epoch < 10: + return [{"name": "advanced_math", ...} for _ in samples] + else: + return [{"name": "competition_math", ...} for _ in samples] + +**Custom Scoring Functions** + +You can integrate custom reward functions: + +.. code-block:: python + + async def calculate_score(self, instance_id, **kwargs): + response = self._instance_dict[instance_id]["response"] + ground_truth = self._instance_dict[instance_id]["ground_truth"] + + # Custom evaluation logic + if custom_evaluation_function(response, ground_truth): + return 1.0 + else: + return 0.0 + +**Multi-step Interactions** + +For complex scenarios requiring multiple feedback rounds: + +.. code-block:: python + + async def generate_response(self, instance_id, messages, **kwargs): + instance = self._instance_dict[instance_id] + instance["attempts"] += 1 + + # Evaluate current response + reward = await self.calculate_score(instance_id) + + if reward > 0.8: + return True, "Excellent work!", reward, {} + elif instance["attempts"] < 3: + return False, "Good attempt, but try to improve...", reward, {} + else: + return True, "Maximum attempts reached.", reward, {} + +Troubleshooting +--------------- + +**Common Issues** + +1. **Instance ID Conflicts**: Ensure unique instance IDs across concurrent sessions +2. **Memory Leaks**: Always call ``finalize_interaction()`` to clean up resources +3. **Blocking Operations**: Keep interaction logic async and non-blocking +4. **Configuration Errors**: Verify interaction config path and class name are correct +5. **Interaction Name Conflicts**: Ensure all interactions have unique names in the configuration +6. **Missing Interaction**: Verify the ``name`` field in ``interaction_kwargs`` matches available interactions +7. **Backward Compatibility**: When migrating from single to multi-interaction, add ``name`` fields to existing data + +**Debugging** + +Enable debug logging to trace interaction flow: + +.. code-block:: bash + + export VERL_LOGGING_LEVEL=DEBUG + +**Performance Monitoring** + +Monitor interaction performance impact on training throughput and adjust accordingly. + +Related Documentation +-------------------- + +- :doc:`multiturn`: Basic multi-turn rollout configuration +- :doc:`sandbox_fusion`: Tool integration with SGLang +- :doc:`search_tool_example`: Search tool implementation example \ No newline at end of file diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst index 970ba46c1..5a4c444cb 100644 --- a/docs/sglang_multiturn/multiturn.rst +++ b/docs/sglang_multiturn/multiturn.rst @@ -1,6 +1,8 @@ Multi-turn Rollout Support ========================== +Last updated: 06/27/2025. + Basic Configuration ~~~~~~~~~~~~~~~~~~~ @@ -24,7 +26,8 @@ For custom environment interaction tools, you can implement your own tools based tools: - class_name: "" - config: {} + config: + type: native tool_schema: You may refer to GSM8KTool_example_configuration_, which is one example of the tool configurations. Its implementation can be found in gsm8k_tool.py_. @@ -38,7 +41,269 @@ Finally, set the ``tools_config_file`` in your rollout config: tool_kwargs: tools_config_file: -This allows integration of customized tool behaviors during actor rollout steps. +This allows integration of customized tool behaviors during actor rollout steps. + +If you want rollout with simulated interaction, you can set the ``interaction_config_file`` in your rollout config: + +.. code-block:: yaml + + interaction: + - class_name: "" + config: {} + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + interaction_config_file: + +If your tool creates multi-modal inputs, you should return a list of multi-modal inputs in your tool.execute() implementation. + +Image and video should be processed before returning. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations: + +.. code-block:: python + + async def execute(self, ...) -> Tuple[str | Dict[str, Any], float, dict]: + ... + from verl.utils.dataset.vision_utils import process_image, process_video + + img1 = process_image(img1) + video1 = process_video(video1) + + # due to the (image | video) key is ("image" | "video") instead of ("images" | "videos") in vllm, we need to use ("image" | "video") to specify list of images/videos + # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 + return {"image": [img1, ...], "video": [video1, ...], "text": "..."}, 0, {} + +remeber to set ``return_multi_modal_inputs: False`` in your dataset config in order to process the multi-modal inputs in the rollout correctly. +Refer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details. + +MCP Tool Configuration +~~~~~~~~~~~~~~~~~~~~~~ + +For MCP interaction tools, you can flexibly configure them using a YAML file. The typical setup is as follows: + +.. code-block:: yaml + + tools: + - class_name: "" + config: + type: mcp + mcp: + mcp_servers_config_path: ./mcp_server.json + tool_selected_list: {} + +The ``tool_selected_list`` field is optional and specifies which tools to use from the servers. If you want to enable all available tools, simply omit this attribute. Besides, ``mcp_servers_config_path`` points to a JSON file containing the MCP server configurations. For example: + +.. code-block:: json + + { + "mcpServers": { + "SSE Server": { + "url": "your_server_url", + "auth_token": "your_server_api_token" + }, + "STDIO Server": { + "command": "npx", + "args": ["-y", "server-mcp@0.2.1"], + "env": { + "SERVER_API_KEY": "your_server_api_token" + } + } + } + } + +Since the content formats returned by the MCP server may vary, users can inherit from ``MCPBaseTool`` and override the ``_parse_tool_result`` method to implement custom parsing logic. + +.. code-block:: python + + class MCPYourTool(MCPBaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + + def _parse_tool_result(self, content: list) -> Tuple[str, dict]: + ... + +Overall, you may refer to mcp_search_tool.py_ and mcp_tool_config.yaml_ for custom implementation and configuration. + +Multi-turn Tokenization +~~~~~~~~~~~~~~~~~~~~~~~ + +Tokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it's hard to identify which tokens belong to assistant messages. Since the token list is flat, it lacks direct alignment with the message roles. + +To address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we: + +1. Apply the chat template to all prior messages (`messages[:i]`). +2. Apply the chat template again including the latest message (`messages[:i+1]`). +3. Tokenize only the *delta* between these two serialized message strings. + +This ensures that only tokens generated by the assistant are included in the loss mask. + +.. code-block:: python + + # When using tokenizer + # Exclude the assistant prompt (e.g., "<|im_start|>assistant") from the loss by setting add_generation_prompt=True + prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False) + curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False) + token_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False) + loss_mask += [1] * len(token_ids) # Mask only the new assistant tokens + +.. code-block:: python + + # When using processor + # Exclude the assistant prompt (e.g., "<|im_start|>assistant") from the loss by setting add_generation_prompt=True + prev = processor.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False) + prev_model_inputs = processor(text=prev, images=images, videos=videos, return_tensors="pt")[0].tolist() + curr = processor.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False) + curr_model_inputs = processor(text=curr, images=images, videos=videos, return_tensors="pt")[0].tolist() + token_ids += curr_model_inputs["input_ids"][len(prev_model_inputs["input_ids"]):] + loss_mask += [1] * len(token_ids) # Mask only the new assistant tokens + +While we've validated this produces consistent results with full message tokenization, future models' chat template could break compatibility. To guard against silent inconsistencies, we compare the delta-based tokenization with full-tokenization results by default at the end of each rollout. + +If you see the following warning, you can check the mismatched substring in the log: + +.. code-block:: + + Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md. + +The tokenization sanity check mode can be configured using the ``actor_rollout_ref.rollout.multi_turn.tokenization_sanity_check_mode`` parameter, which accepts the following values: + +- ``strict`` (default): Performs strict comparison between delta-based and full tokenization results, raising warnings for any differences. + +- ``ignore_strippable``: Ignores differences in whitespace characters (``\n``, ``\t``, ``\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable. + +- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training. + +Example configuration: + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + multi_turn: + tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "disable", "ignore_strippable", "strict" + +Handling Multi-Modal Inputs in Datasets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the return_multi_modal_inputs flag in your dataset config (used by RLHFDataset). + +- ``return_multi_modal_inputs: True`` (default): The dataset will pre-process and include a multi_modal_inputs dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors, etc.) as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch. + +- ``return_multi_modal_inputs: False``: The dataset will not include the multi_modal_inputs field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch. + + +Special Cases +^^^^^^^^^^^^^ + +Some models (e.g., Qwen/QwQ-32B and Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate. + +For example, for the following conversation: + +.. code-block:: python + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "user asked about a simple math question. 2 + 2 = 4."}, + {"role": "user", "content": "Explain why."}, + {"role": "assistant", "content": "user wants to know the reasoning behind the answer. Search for a good explanation", + "tool_calls": [{"id": "tool1", "type": "search", "arguments": {"query": "Why is 2 + 2 = 4?"}}]}, + {"role": "tool", "content": "The sum of two and two is four because it is a basic arithmetic operation."}, + {"role": "assistant", "content": "The tool provided a good explanation.The sum of two and two is four because it is a basic arithmetic operation."} + ] + +1. Qwen/QwQ-32B will remove all reasoning content except the last assistant message after applying the chat template. + +.. code-block:: text + + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What is 2 + 2?<|im_end|> + <|im_start|>assistant + 2 + 2 = 4.<|im_end|> + <|im_start|>user + Explain why.<|im_end|> + <|im_start|>assistant + + {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} + <|im_end|> + <|im_start|>user + + The sum of two and two is four because it is a basic arithmetic operation. + <|im_end|> + <|im_start|>assistant + The tool provided a good explanation. The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> + +2. Qwen3 series will remove all reasoning content before the last user message. + +.. code-block:: text + + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What is 2 + 2?<|im_end|> + <|im_start|>assistant + 2 + 2 = 4.<|im_end|> + <|im_start|>user + Explain why.<|im_end|> + <|im_start|>assistant + + user wants to know the reasoning behind the answer. Search for a good explanation + + + + {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} + <|im_end|> + <|im_start|>user + + The sum of two and two is four because it is a basic arithmetic operation. + <|im_end|> + <|im_start|>assistant + + The tool provided a good explanation. + + + The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> + +To handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn't include assistant messages or reasoning content, it remains consistent across turns. + +.. code-block:: python + + BASE_CHAT_HISTORY = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "I am a user."} + ] + prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False) + curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False) + token_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False) + loss_mask += [1] * len(token_ids) + +This method works well for Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template. A fix_ has been proposed but not yet adopted. Until then, use the following command to download the fixed model revision: + +.. code-block:: bash + + pip install huggingface_hub + huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81 + +.. _fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81 + +Discrepancy Between Training and Inference Templates +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Although the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not. + +This mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout. + +However, this approach comes with trade-offs: + +1. Long reasoning contents can easily exceed the model's context window, especially in multi-turn rollout. +2. There's a mismatch between rollout and production environment now—models will not have reasoning content from past turns if you use the default chat template in production. + +We are still evaluating the impact of these issues. If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable: + +``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True`` GSM8K Multi-turn Training Performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -51,10 +316,28 @@ See the training performance of multi-turn rollout on the GSM8K task HERE_. .. _gsm8k_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/gsm8k_tool.py +.. _mcp_search_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/mcp_search_tool.py + +.. _mcp_tool_config.yaml: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml + +Interaction System +~~~~~~~~~~~~~~~~~~ + +For dynamic conversational feedback during RL training, see: + +.. toctree:: + :maxdepth: 1 + + interaction_system + Search Tool Integration ~~~~~~~~~~~~~~~~~~~~~~~ .. toctree:: :maxdepth: 1 - search_tool_example \ No newline at end of file + search_tool_example + +Code Walkthrough +~~~~~~~~~~~~~~~~~~~~~~~ +If you want to learn more in depth about the code execution flow, please read https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/rlhf/verl/multi-turn/code-walk-through diff --git a/docs/sglang_multiturn/sandbox_fusion.rst b/docs/sglang_multiturn/sandbox_fusion.rst index 1f9f675d3..207af5289 100644 --- a/docs/sglang_multiturn/sandbox_fusion.rst +++ b/docs/sglang_multiturn/sandbox_fusion.rst @@ -2,6 +2,8 @@ Sandbox Fusion Tool Integration =============================== +Last updated: 06/10/2025. + Motivations =========== @@ -208,79 +210,83 @@ Unit Tests e2e Tests ---------- -we provide e2e test scripts in `tests/e2e` folder, named `tests/e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh` +we provide e2e test scripts in `tests/special_e2e` folder, named `tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh` by setting 'trainer.rollout_data_dir' you can dump the rollout data to local disk. here is an sample taken from the rollout data: -{ - "input": " - - system\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n - - user\nJohn gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n - - assistant\n", - "output": "assistant\n\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n - - \n{\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n\n - - tool\n220000.0\n\n - - assistant\n\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n\n\n#### 220000.0", - "score": 0, - "step": 1 -} - -here is the readable format version: - -[system] - -You are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `. - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{"type": "function", "function": {"name": "code_interpreter", "description": "A tool for executing code.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to execute.", "enum": null}}, "required": ["code"]}, "strict": false}} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - - -[user] - -John gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`. - -[assistant] - - -Okay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together. - -First, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year. - -Last year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that. - -10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year. - -This year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy. - -{"name": "code_interpreter", "arguments": {"code": "total_pay_this_year = 200000\nbonus_percentage = 10 / 100\nbonus_this_year = total_pay_this_year * bonus_percentage\ntotal_income_this_year = total_pay_this_year + bonus_this_year\nprint(total_income_this_year)", "executes": "True"}} - - -[tool] - -220000.0 +.. code-block:: python -[assistant] + { + "input": " + + system\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n + + user\nJohn gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n + + assistant\n", + "output": "assistant\n\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n + + \n{\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n\n + + tool\n220000.0\n\n + + assistant\n\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n\n\n#### 220000.0", + "score": 0, + "step": 1 + } - -So the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000. +here is the readable format version: -I need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters. - +.. code-block:: python -#### 220000.0 \ No newline at end of file + [system] + + You are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `. + + # Tools + + You may call one or more functions to assist with the user query. + + You are provided with function signatures within XML tags: + + {"type": "function", "function": {"name": "code_interpreter", "description": "A tool for executing code.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to execute.", "enum": null}}, "required": ["code"]}, "strict": false}} + + + For each function call, return a json object with function name and arguments within XML tags: + + {"name": , "arguments": } + + + [user] + + John gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`. + + [assistant] + + + Okay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together. + + First, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year. + + Last year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that. + + 10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year. + + This year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy. + + {"name": "code_interpreter", "arguments": {"code": "total_pay_this_year = 200000\nbonus_percentage = 10 / 100\nbonus_this_year = total_pay_this_year * bonus_percentage\ntotal_income_this_year = total_pay_this_year + bonus_this_year\nprint(total_income_this_year)", "executes": "True"}} + + + [tool] + + 220000.0 + + [assistant] + + + So the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000. + + I need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters. + + + #### 220000.0 diff --git a/docs/sglang_multiturn/search_tool_example.rst b/docs/sglang_multiturn/search_tool_example.rst index 4fac6ef3c..cbbdeb0d0 100644 --- a/docs/sglang_multiturn/search_tool_example.rst +++ b/docs/sglang_multiturn/search_tool_example.rst @@ -1,6 +1,9 @@ ======================= Search Tool Integration ======================= + +Last updated: 05/30/2025. + Introduction ------------ - We have added a search tool calling function to Multi-Turn RL, enabling the model to initiate retrieval requests during Actor rollout and directly use retrieval results for training. **We support using a local dense retriever as the retrieval tool, as well as integrating with your own local retrieval engine.** @@ -211,7 +214,7 @@ To enable multi-turn reasoning, set the following fields in your config: actor_rollout_ref: rollout: - name: "sglang_async" + name: "sglang" multi_turn: enable: True diff --git a/docs/single_controller.rst b/docs/single_controller.rst index 521158c85..d12177854 100644 --- a/docs/single_controller.rst +++ b/docs/single_controller.rst @@ -1,6 +1,8 @@ The Design of ``verl.single_controller`` ============================================== +Last updated: 05/21/2025. + **Author:**\ `Wang Zhang `__ Preface diff --git a/docs/start/agentic_rl.rst b/docs/start/agentic_rl.rst new file mode 100644 index 000000000..60af79f5f --- /dev/null +++ b/docs/start/agentic_rl.rst @@ -0,0 +1,125 @@ +Agentic RL Training +=================== + +Last updated: 07/15/2025. + +Overview +---------- +The goal of Agentic RL is to improve the performance of backend models from reinforcement learning to the Agent. During the training process, a series of features are developed: + +1. Server-based asynchronous rollout +2. Multi-turn conversations and tool calls +3. LangGraph-based Agent + + +This document explains the system principles and usage involved to help users implement Agentic RL. + + +Server-based Asynchronous Rollout +--------------------------------- + +Since Agents need to interact with the environment through various tool calls, in order to avoid GPU idling while waiting for tool call return results, an asyncio based co-routing mechanism is utilized to execute each rollout requests asynchronously, thereby improving training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives: + +1. Enabling load balancing mechanisms to balance loads across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, scheduling capabilities in stream mode (recipe\stream_mode) are implemented as a recipe. +2. Preventing agent specific features such as tracing from affecting the inference engine. + +System Architecture +~~~~~~~~~~~~~~~~~~~ + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true + +System Components +~~~~~~~~~~~~~~~~~ + ++--------------------------+----------------------------------------------------------------------------+ +| Component | Role | ++==========================+============================================================================+ +| AgentLoop | Client, implements Agent functions | ++--------------------------+----------------------------------------------------------------------------+ +| AsyncLLMServerManager | Inference gateway, provides generate interface for AgentLoop | ++--------------------------+----------------------------------------------------------------------------+ +| AsyncServer | Server, each instance is connected to one DP group of the inference engine | ++--------------------------+----------------------------------------------------------------------------+ + +**"generate" Interface** + +The "generate" function based on ray actor is used between the Client and Server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible. For example, the token converted from "" will be different from that generated by the LLM. During the training phase, it is necessary to strictly use the tokens generated by LLM inference to avoid inaccurate in computing advantage, which may affect model performance. Having the Server provide a token-based API helps the Client maintain the relationship between the text generated by tool calls and the tokens returned by the LLM, so as to output correct tokens for training. + + +**Inference Engine Adaptation** +AsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM to hide underlying differences: + +1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to remotely call async_generate through ray actor. +2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be directly called in AsyncServer. + + +Usage Example +~~~~~~~~~~~~~ + +Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints. +This example uses the sglang inference engine by default, and you can also modify rollout_name to use vllm. + +.. code-block:: bash + + bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh + + +Multi-turn Conversations and Tool Calls +--------------------------------------- + +Follow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare tool and configuration files. + +The Tool Agent Loop has an additional requirement: adding an "agent_name" field to the dataset. During rollout, it will choose to use tool_agent_loop or single_turn_agent (default) based on this field. + +Usage Example +~~~~~~~~~~~~~ + +.. code-block:: bash + + # install mlflow to view toolcall and llm trace + pip install mlflow + + # This will download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the "agent_name" field. + bash examples/data_preprocess/gsm8k_tool_agent_loop.py + + # Start training with tool calls and enabled mlflow based trace helping to debug the rollout details + bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh + + # When training is done, start a mlflow server to view trace + mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db + + # then you can open http://:5000 from browser to view trace + + +Note: During training, because the model may sometimes fail to generate correct toolcall tags, an error message "Failed to decode tool call" will be output to the console, which does not indicate an abnormality in training. + + +Follow :doc:`Rollout trace<../advance/rollout_trace>` to known more about trace feature. + + + +Agent Framework +--------------- + +System Architecture +~~~~~~~~~~~~~~~~~~~ + +.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true + +System Components +~~~~~~~~~~~~~~~~~ + ++--------------------------+-----------------------------------------------------------------------------------------------+ +| Component | Role | ++==========================+===============================================================================================+ +| ChatModel | LLM object of LangChain, used to adapt to the “generate” api provided by AsyncLLMServerManager| ++--------------------------+-----------------------------------------------------------------------------------------------+ +| RectAgentLoop | Agent adaptation layer, which by default supports a naive LangGraph Agentic. | +| | New classes can be derived to support user-defined Agents, and the run function needs to be | +| | implemented to complete Agent calls. | ++--------------------------+-----------------------------------------------------------------------------------------------+ +| AsyncServer | Server, each instance is connected to one DP group of the inference engine. | ++--------------------------+-----------------------------------------------------------------------------------------------+ + + +Follow doc "recipe/langgraph_agent/example/README.md" for more details. \ No newline at end of file diff --git a/docs/start/install.rst b/docs/start/install.rst index d1579b7ed..12c9c3531 100644 --- a/docs/start/install.rst +++ b/docs/start/install.rst @@ -19,75 +19,108 @@ Choices of Backend Engines We recommend using **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using FSDP backend can be found in :doc:`FSDP Workers<../workers/fsdp_workers>`. -For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support `Megatron-LM v0.11 `_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`. +For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support `Megatron-LM v0.12.1 `_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`. -.. note:: - - verl directly supports megatron's `GPTModel` API on the main branch with mcore v0.11. For mcore v0.4 try `0.3.x branch `_ instead. 2. Inference: -For inference, vllm 0.6.3 and 0.8.2 have been tested for stability. Avoid using vllm 0.7x due to reported issues with its functionality. +For inference, vllm 0.8.3 and later versions have been tested for stability. We recommend turning on env var `VLLM_USE_V1=1` for optimal performance. -For SGLang, refer to the :doc:`SGLang Backend<../workers/sglang_worker>` for detailed installation and usage instructions. **SGLang offers better throughput and is under extensive development.** We encourage users to report any issues or provide feedback via the `SGLang Issue Tracker `_. +For SGLang, refer to the :doc:`SGLang Backend<../workers/sglang_worker>` for detailed installation and usage instructions. SGLang rollout is under extensive development and offers many advanced features and optimizations. We encourage users to report any issues or provide feedback via the `SGLang Issue Tracker `_. For huggingface TGI integration, it is usually used for debugging and single GPU exploration. Install from docker image ------------------------- -We provide pre-built Docker images for quick setup. +We provide pre-built Docker images for quick setup. And from this version, +we utilize a new image release hierarchy for productivity and stability. + +The image types are divided into three large categories: + +- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. + Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA. +- **Application Image**: Stable version with inference and training frameworks installed. +- **Community Image**: Unstable version with the latest frameworks and features. + +The first two types of images are hosted on dockerhub `verlai/verl `_ repository, while the preview images are hosted on community repository. + +.. note:: + + The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``. + +Base Image +:::::::::: + +The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``. + +The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. From verl0.5, images are built with `Deep-EP `_ for efficient EP communication. + +The update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages. + +Application Image +::::::::::::::::: + +From this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer. + +There are four types of application images available: + +- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1``, with Deep-EP support: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1-deepep``. +- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`` (need vLLM support, but can have some package conflicts), with Deep-EP support: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1-deepep``. +- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.1`` +- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.1`` + +The latest vLLM support is coming soon. + +Docker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details. -For vLLM with Megatron or FSDP, please use the stable version of image ``whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3``. +Application images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks. -For latest vLLM with FSDP, please refer to ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. +Community Image +::::::::::::::: -For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. +Community images are provided by the community, including the latest versions of vLLM and SGLang, and may include experimental features or configurations. And also works for other hardwares or platforms like AMD GPUs with ROCM or AWS EFA and Sagemaker. + +For latest vLLM with FSDP, please refer to `hiyouga/verl `_ repository and the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. + +For latest SGLang with FSDP, please refer to `ocss884/verl-sglang `_ repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. See files under ``docker/`` for NGC-based image or if you want to build your own. +Note that For aws instances with EFA net interface (Sagemaker AI Pod), +you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa`` + +Installation from Docker +:::::::::::::::::::::::: + +After pulling the desired Docker image and installing desired inference and training frameworks, you can run it with the following steps: + 1. Launch the desired Docker image and attach into it: .. code:: bash - docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl + docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl sleep infinity docker start verl docker exec -it verl bash -2. Inside the container, install latest verl: +2. If you use the images provided, you only need to install verl itself without dependencies: -.. code:: bash +.. note:: # install the nightly version (recommended) git clone https://github.com/volcengine/verl && cd verl - # pick your choice of inference engine: vllm or sglang - # pip3 install -e .[vllm] - # pip3 install -e .[sglang] - # or install from pypi instead of git via: - # pip3 install verl[vllm] - # pip3 install verl[sglang] + pip3 install --no-deps -e . -.. note:: - - The Docker image ``whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3`` is built with the following configurations: - - - **PyTorch**: 2.6.0+cu124 - - **CUDA**: 12.4 - - **cuDNN**: 9.8.0 - - **nvidia-cudnn-cu12**: 9.8.0.87, **important for the usage of Megatron FusedAttention with MLA Support** - - **Flash Attenttion**: 2.7.4.post1 - - **Flash Infer**: 0.2.2.post1 - - **vLLM**: 0.8.5 - - **SGLang**: 0.4.6.post5 - - **Megatron-LM**: core_v0.12.0 - - **TransformerEngine**: 2.3 - - **Ray**: 2.44.1 +[Optional] If you hope to switch between different frameworks, you can install verl with the following command: .. note:: - For aws instances with EFA net interface (Sagemaker AI Pod), - you need to install EFA driver as shown in ``docker/Dockerfile.awsefa`` + # install the nightly version (recommended) + git clone https://github.com/volcengine/verl && cd verl + pip3 install -e .[vllm] + pip3 install -e .[sglang] + Install from custom environment --------------------------------------------- @@ -102,6 +135,11 @@ For training and inference engines to utilize better and faster hardware support and some of the dependencies are easy to be overridden when installing other packages, so we put them in the :ref:`Post-installation` step. +.. note:: + + The installation steps below are recommended configurations for the latest version of verl. + If you are trying to customize your own environment, please ignore the strict constraints. + We need to install the following pre-requisites: - **CUDA**: Version >= 12.4 @@ -217,7 +255,7 @@ If you encounter issues about package versions during running verl, please updat Install with AMD GPUs - ROCM kernel support ------------------------------------------------------------------ -When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it. +When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it. If you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su `_. Find the docker for AMD ROCm: `docker/Dockerfile.rocm `_ @@ -297,7 +335,7 @@ Launch the container verl-rocm \ /bin/bash -(Optional): If you do not want to root mode and require assign yuorself as the user -Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script. +If you do not want to root mode and require assign yourself as the user, +Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script. -(Currently Support): Training Engine: FSDP; Inference Engine: vLLM and SGLang - We will support Megatron in the future. +verl with AMD GPUs currently supports FSDP as the training engine, vLLM and SGLang as the inference engine. We will support Megatron in the future. diff --git a/docs/start/more_resources.rst b/docs/start/more_resources.rst new file mode 100644 index 000000000..aa8cb2a62 --- /dev/null +++ b/docs/start/more_resources.rst @@ -0,0 +1,7 @@ +More Resources +============== + +Last updated: 06/30/2025. + +- Introduction to verl (`Slides `_) +- verl Code Walkthrough (`Slides `_, `Talk in Chinese `_) diff --git a/docs/start/multinode.rst b/docs/start/multinode.rst index 844a2c9dc..9e058055d 100644 --- a/docs/start/multinode.rst +++ b/docs/start/multinode.rst @@ -1,6 +1,8 @@ Multinode Training ================== +Last updated: 06/10/2025. + .. _wuxibin89: https://github.com/wuxibin89 Author: `Xibin Wu `_, `Yusheng Su `_. @@ -62,10 +64,6 @@ Submit job to ray cluster .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job.png?raw=true .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job_detail.png?raw=true -.. note:: - - From Ray 2.20, ``ray job submit`` or ``client = JobSubmissionClient("http://127.0.0.1:8265")`` is deprecated in current environment, and Ray version less than 2.40 is not compatible with current version of verl. We recommend you upgrade to Ray latest version and directly execute the training scripts. - Slurm ----- @@ -177,7 +175,6 @@ Now you can submit the training job to the Ray cluster which is available at ``l trainer.project_name=ppo_training \ trainer.experiment_name=qwen-2.5-7B \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=8 \ trainer.nnodes=2 \ trainer.default_local_dir=/checkpoints \ @@ -456,8 +453,6 @@ slurm_script.sh echo "IP Head: $ip_head" # make sure we set environment variables before Ray initialization - # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: - # export VLLM_ATTENTION_BACKEND=XFORMERS # Print out all env variables printenv @@ -575,7 +570,7 @@ slurm_script.sh critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.0001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \ diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst index 99db020a9..22b8388a2 100644 --- a/docs/start/quickstart.rst +++ b/docs/start/quickstart.rst @@ -70,7 +70,7 @@ answer from both the solution and model's output using regular expression matching. We assign a reward of 1 to correct answer, 0.0 to incorrect answer and 0 to no answer. -For more details, please refer to `verl/utils/reward_score/gsm8k.py `_. +For more details, please refer to `verl/utils/reward_score/gsm8k.py `_. **Training Script** @@ -78,7 +78,7 @@ Now let's run PPO training with the dataset and model above. [2]_ Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths. -You may set ``VERL_USE_MODELSCOPE=True`` to download models from modelscope instead of huggingface. +You may set ``VERL_USE_MODELSCOPE=True`` to download models from `modelscope `_ instead of `huggingface `_. .. code-block:: bash @@ -100,9 +100,8 @@ You may set ``VERL_USE_MODELSCOPE=True`` to download models from modelscope inst critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=1 \ trainer.nnodes=1 \ trainer.save_freq=10 \ @@ -116,15 +115,24 @@ You are expected to see the following logs, indicating training in progress. The step:0 - timing/gen:21.470 - timing/ref:4.360 - timing/values:5.800 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty_coeff:0.001 - timing/adv:0.109 - timing/update_critic:15.664 - critic/vf_loss:14.947 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.056 - critic/grad_norm:1023.278 - critic/lr(1e-4):0.100 - timing/update_actor:20.314 - actor/entropy_loss:0.433 - actor/pg_loss:-0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:1.992 - actor/lr(1e-4):0.010 - critic/score/mean:0.004 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.004 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.360 - critic/advantages/min:-2.280 - critic/returns/mean:0.003 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.045 - critic/values/max:9.500 - critic/values/min:-14.000 - response_length/mean:239.133 - response_length/max:256.000 - response_length/min:77.000 - prompt_length/mean:104.883 - prompt_length/max:175.000 - prompt_length/min:68.000 step:1 - timing/gen:23.020 - timing/ref:4.322 - timing/values:5.953 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty:0.001 - timing/adv:0.118 - timing/update_critic:15.646 - critic/vf_loss:18.472 - critic/vf_clipfrac:0.384 - critic/vpred_mean:1.038 - critic/grad_norm:942.924 - critic/lr(1e-4):0.100 - timing/update_actor:20.526 - actor/entropy_loss:0.440 - actor/pg_loss:0.000 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - actor/grad_norm:2.060 - actor/lr(1e-4):0.010 - critic/score/mean:0.000 - critic/score/max:0.000 - critic/score/min:0.000 - critic/rewards/mean:0.000 - critic/rewards/max:0.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:2.702 - critic/advantages/min:-2.616 - critic/returns/mean:0.000 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.280 - critic/values/max:11.000 - critic/values/min:-16.000 - response_length/mean:232.242 - response_length/max:256.000 - response_length/min:91.000 - prompt_length/mean:102.398 - prompt_length/max:185.000 - prompt_length/min:70.000 -Checkout :ref:`algo-baseline-page` for full training and validation logs for reference. +Checkout ``Algorithm Baselines`` page for full training and validation logs for reference. -The checkpoint is saved at the following dir by default: ``checkpoints/${trainer.project_name}/${trainer.experiment_name}`` +The checkpoint is saved at the following dir by default: ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``. You can merge the saved checkpoints to huggingface model using ``verl.model_merger`` module, for example: + +.. code-block:: bash + + python3 -m verl.model_merger merge \ + --backend fsdp \ + --local_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor \ + --target_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor/huggingface + +For more details about checkpoint and model merging, please refer to :ref:`checkpoint-page`. To enable ``wandb`` for experiment tracking, set the following configs: .. code-block:: bash - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name=$YOUR_PROJECT_NAME \ trainer.experiment_name=$YOUR_RUN_NAME \ diff --git a/docs/start/ray_debug_tutorial.rst b/docs/start/ray_debug_tutorial.rst index 334bd72d5..9e7c87dfa 100644 --- a/docs/start/ray_debug_tutorial.rst +++ b/docs/start/ray_debug_tutorial.rst @@ -1,6 +1,9 @@ Ray Debug Tutorial ================== +Last updated: 04/23/2025 + + .. _wuxibin89: https://github.com/wuxibin89 Author: `Ao Shen `_. diff --git a/docs/workers/fsdp_workers.rst b/docs/workers/fsdp_workers.rst index b1c78e7c8..b158fb265 100644 --- a/docs/workers/fsdp_workers.rst +++ b/docs/workers/fsdp_workers.rst @@ -1,6 +1,8 @@ PyTorch FSDP Backend ====================== +Last updated: 02/12/2025. + We support PyTorch FSDP Backend by implementing various workers for actor, critic, reference, rollout and reward models. We also implement the ``FSDPVLLMShardingManager`` that reshard weight between FSDP and diff --git a/docs/workers/megatron_workers.rst b/docs/workers/megatron_workers.rst index c7ae31740..b93bd033c 100644 --- a/docs/workers/megatron_workers.rst +++ b/docs/workers/megatron_workers.rst @@ -1,6 +1,8 @@ Megatron-LM Backend =================== +Last updated: 06/24/2025. + We support Megatron Backend by implementing various workers for actor, critic, reference, rollout and reward models. We also implement the ``3DHybridEngine`` using Megatron-LM and vLLM/SGLang in @@ -77,7 +79,7 @@ MegatronWorker ``MegatronWorker`` is the base class of different megatron worker classes. In this class, ``get_megatron_global_info`` and -``get_megatron_rank_info`` function to retrive the 3D parallel world +``get_megatron_rank_info`` function to retrieve the 3D parallel world size and rank of each ``Worker`` running on specific GPU. These information will be used in transfer protocol for Megatron Backend. @@ -113,27 +115,19 @@ initialization process. The initialization details of HybridEngine, Actor and Rollout are highlighted below: -1. ``AllGatherPPModel`` holds memory buffer for both Actor and Rollout - and support weight resharding between actor and rollout. -2. ``MegatronPPOActor`` implements the simple PPO computation logics +1. ``MegatronPPOActor`` implements the simple PPO computation logics when the model is built with Megatron, including compute log prob, model update. -3. ``vLLMRollout`` support generation with vLLM. We modify the vLLM +2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM Engine and make it executed under SPMD to fit into our ``WorkerGroup`` design. -4. ``MegatronVLLMShardingManager`` a context manager to perform actual +3. ``MegatronVLLMShardingManager`` a context manager to perform actual resharding between actor and rollout. See `source code `_ for more information. .. code:: python - # Initialize the 3D HybridEngine - hybrid_engine = AllGatherPPModel(model_provider=megatron_actor_model_provider) - # Fetch the model at current rank - actor_module = hybrid_engine.this_rank_models - ... - # build actor model self.actor = MegatronPPOActor(config=self.config.actor, model_config=self.actor_model_config, @@ -156,7 +150,7 @@ See `source code `_. \ No newline at end of file +kinds of models, please refer to `MCore Document `_. diff --git a/docs/workers/ray_trainer.rst b/docs/workers/ray_trainer.rst index 29873ce22..9c482d39a 100644 --- a/docs/workers/ray_trainer.rst +++ b/docs/workers/ray_trainer.rst @@ -1,6 +1,8 @@ PPO Ray Trainer =============== +Last updated: 02/12/2025. + We implement the RayPPOTrainer, which is a trainer runs on the driver process on a single CPU/GPU node (default is CPU). diff --git a/docs/workers/sglang_worker.rst b/docs/workers/sglang_worker.rst index df208a45a..1ef93823c 100644 --- a/docs/workers/sglang_worker.rst +++ b/docs/workers/sglang_worker.rst @@ -1,5 +1,8 @@ SGLang Backend ============== + +Last updated: 05/31/2025. + **Authored By SGLang RL Team and listed alphabetically by last name** `Jingyi Chen `_, `Yitong Guan `_, `Zhuobin Huang `_, `Jiajun Li `_, `Ji Li `_, `Shenggui Li `_, `Junrong Lin `_, `Xiang Long `_, `Rui Lu `_, `Jin Pan `_, `Shuai Shi `_, `Yushen Su `_, `Xinyuan Tong `_, `Chendong Wang `_, `Hanchen Zhang `_, `Haoran Wang `_, `Yongan Xiang `_, `Chengxing Xie `_, `Yuhao Yang `_, `Jinwei Yao `_, `Qiaolin Yu `_, `Yuzhen Zhou `_, `Chenyang Zhao `_ @@ -73,9 +76,8 @@ We use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test. critic.model.fsdp_config.param_offload=True \ critic.model.fsdp_config.optimizer_offload=True \ algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=4 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ @@ -226,9 +228,8 @@ You can see that the cluster has two nodes with 16 GPUs: critic.model.fsdp_config.optimizer_offload=True \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=True \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=8 \ trainer.nnodes=2 \ trainer.save_freq=-1 \ diff --git a/examples/data_preprocess/hellaswag.py b/examples/data_preprocess/hellaswag.py new file mode 100644 index 000000000..1b3f20080 --- /dev/null +++ b/examples/data_preprocess/hellaswag.py @@ -0,0 +1,96 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess Hellaswag dataset. + +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def preprocess(text): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="/opt/tiger/hellaswag") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + data_source = "Rowan/hellaswag" + + dataset = datasets.load_dataset(data_source, trust_remote_code=True) + + train_dataset = dataset["train"] + val_dataset = dataset["validation"] + test_dataset = dataset["test"] + + instruction = "Please complete the following sentence.\n" + + def make_map_fn(split): + def process_fn(doc, idx): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + query = preprocess(doc["activity_label"] + ": " + ctx) + choices = [preprocess(ending) for ending in doc["endings"]] + gold = int(doc["label"]) + + data = { + "data_source": data_source, + "prompt": [{"role": "user", "content": query}], + "ability": "nlp", + "reward_model": { + "style": "model", + "eval": "multiple_choice", # using loglikelihood + "ground_truth": gold, + "choices": choices, + }, + "extra_info": {"split": split, "index": idx}, + } + return data + + return process_fn + + # filter data that doesn't have a label + train_dataset = train_dataset.filter(lambda x: len(x["label"]) > 0) + val_dataset = val_dataset.filter(lambda x: len(x["label"]) > 0) + test_dataset = test_dataset.filter(lambda x: len(x["label"]) > 0) + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + val_dataset = val_dataset.map(function=make_map_fn("validation"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + val_dataset.to_parquet(os.path.join(local_dir, "validation.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/data_preprocess/math_dataset.py b/examples/data_preprocess/math_dataset.py new file mode 100644 index 000000000..e2e5d3524 --- /dev/null +++ b/examples/data_preprocess/math_dataset.py @@ -0,0 +1,81 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the MATH-lighteval dataset to parquet format +""" + +import argparse +import os + +import datasets + +from verl.utils.hdfs_io import copy, makedirs +from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed + + +def extract_solution(solution_str): + return remove_boxed(last_boxed_only_string(solution_str)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/math") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + # 'lighteval/MATH' is no longer available on huggingface. + # Use mirror repo: DigitalLearningGmbH/MATH-lighteval + data_source = "DigitalLearningGmbH/MATH-lighteval" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = datasets.load_dataset(data_source, trust_remote_code=True) + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = "Let's think step by step and output the final answer within \\boxed{}." + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + question = example.pop("problem") + + question = question + " " + instruction_following + + answer = example.pop("solution") + solution = extract_solution(answer) + data = { + "data_source": data_source, + "prompt": [{"role": "user", "content": question}], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/generation/run_deepseek_v2_lite_math.sh b/examples/generation/run_deepseek_v2_lite_math.sh index 6f13fef53..0c5a74b1f 100644 --- a/examples/generation/run_deepseek_v2_lite_math.sh +++ b/examples/generation/run_deepseek_v2_lite_math.sh @@ -1,7 +1,7 @@ set -x -data_path=$HOME/data/rlhf/gsm8k/test.parquet -save_path=$HOME/data/rlhf/math/deepseek_v2_lite_gen_test.parquet +data_path=$HOME/data/gsm8k/test.parquet +save_path=$HOME/data/gsm8k/deepseek_v2_lite_gen_test.parquet model_path=deepseek-ai/deepseek-llm-7b-chat python3 -m verl.trainer.main_generation \ diff --git a/examples/gpg_trainer/gpg.md b/examples/gpg_trainer/gpg.md new file mode 100644 index 000000000..b40cc83bc --- /dev/null +++ b/examples/gpg_trainer/gpg.md @@ -0,0 +1,34 @@ +# GPG: Group Policy Gradient + +Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning +](https://arxiv.org/abs/2504.02546). + +## Key Components +- Use a corrected advantage function to improve policy gradient accuracy and training efficiency. +- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO) + +## Configuration +To configure GPG within the framework, use the following YAML settings. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + policy_loss: + loss_mode: "gpg" +``` + +## Advanced Extensions +GPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + use_kl_loss: True # enable kl regularization + kl_loss_coef: 0.01 + policy_loss: + loss_mode: "gpg" +``` \ No newline at end of file diff --git a/examples/gpg_trainer/run_qwen2-7b_math.sh b/examples/gpg_trainer/run_qwen2-7b_math.sh new file mode 100755 index 000000000..1454bf294 --- /dev/null +++ b/examples/gpg_trainer/run_qwen2-7b_math.sh @@ -0,0 +1,52 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gpg \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_gpg_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/e2e/run_qwen2-7b_math_megatron.sh b/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh old mode 100644 new mode 100755 similarity index 90% rename from tests/e2e/run_qwen2-7b_math_megatron.sh rename to examples/gpg_trainer/run_qwen2-7b_math_megatron.sh index f5f135530..2317fa07d --- a/tests/e2e/run_qwen2-7b_math_megatron.sh +++ b/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh @@ -14,7 +14,7 @@ test_files="['$gsm8k_test_path', '$math_test_path']" python3 -m verl.trainer.main_ppo --config-path=config \ --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ + algorithm.adv_estimator=gpg \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ @@ -28,7 +28,8 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ + actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.actor.entropy_coeff=0 \ @@ -43,8 +44,8 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_gpg_example_gsm8k_math' \ trainer.experiment_name='qwen2_7b_megatron' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron.sh index 567b7ca4d..2087570ed 100644 --- a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh +++ b/examples/grpo_trainer/run_deepseek671b_math_megatron.sh @@ -1,9 +1,10 @@ set -x # 0. download the config -# only need to download the configuration_deepseek.py and config.json +# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json` # remove the `quantization_config` in the `config.json` # set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported + huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json # 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main @@ -73,7 +74,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.rollout.top_k=-1 \ actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ algorithm.use_kl_in_reward=False \ - trainer.logger=['console','tensorboard'] \ + trainer.logger='["console","tensorboard"]' \ trainer.project_name='verl_megatron_gsm8k_examples' \ trainer.experiment_name='dsv3-32nodes' \ trainer.n_gpus_per_node=8 \ @@ -82,8 +83,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat trainer.test_freq=5 \ +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \ +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_bias_update_rate=0.0 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.0001 \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ @@ -100,7 +99,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - trainer.default_hdfs_dir=null \ trainer.default_local_dir=$CKPT_DIR \ trainer.val_before_train=False \ trainer.total_epochs=100 $@ diff --git a/tests/e2e/run_deepseek_grpo.sh b/examples/grpo_trainer/run_deepseek7b_llm.sh similarity index 92% rename from tests/e2e/run_deepseek_grpo.sh rename to examples/grpo_trainer/run_deepseek7b_llm.sh index d97206819..af9204ab1 100644 --- a/tests/e2e/run_deepseek_grpo.sh +++ b/examples/grpo_trainer/run_deepseek7b_llm.sh @@ -1,7 +1,5 @@ set -x -export VLLM_ATTENTION_BACKEND=XFORMERS - python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ @@ -32,12 +30,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=2 $@ \ No newline at end of file + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/e2e/run_deepseek7b_llm_math.sh b/examples/grpo_trainer/run_deepseek7b_llm_math.sh similarity index 90% rename from tests/e2e/run_deepseek7b_llm_math.sh rename to examples/grpo_trainer/run_deepseek7b_llm_math.sh index a2273915b..198e6f4ae 100644 --- a/tests/e2e/run_deepseek7b_llm_math.sh +++ b/examples/grpo_trainer/run_deepseek7b_llm_math.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/gsm8k/train.parquet gsm8k_test_path=$HOME/data/gsm8k/test.parquet @@ -41,11 +39,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='deepseek_llm_7b_function_rm_math' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_deepseek7b_llm_math_megatron.sh b/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh similarity index 90% rename from tests/e2e/run_deepseek7b_llm_math_megatron.sh rename to examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh index f2c8a9ed4..84d59e2ee 100644 --- a/tests/e2e/run_deepseek7b_llm_math_megatron.sh +++ b/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -43,11 +41,11 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='deepseek_llm_7b_math_megatron' \ trainer.n_gpus_per_node=16 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh index 1870c82d7..72cd4445a 100644 --- a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh +++ b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh @@ -29,7 +29,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen2vl_geo3k_function_rm.sh b/examples/grpo_trainer/run_minicpmo2_6.sh similarity index 56% rename from tests/e2e/run_qwen2vl_geo3k_function_rm.sh rename to examples/grpo_trainer/run_minicpmo2_6.sh index 396c53c41..d95622e1a 100644 --- a/tests/e2e/run_qwen2vl_geo3k_function_rm.sh +++ b/examples/grpo_trainer/run_minicpmo2_6.sh @@ -1,41 +1,49 @@ set -x -export VLLM_ATTENTION_BACKEND=XFORMERS - python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/geo3k/train.parquet \ data.val_files=$HOME/data/geo3k/test.parquet \ data.train_batch_size=128 \ - data.max_prompt_length=1536 \ - data.max_response_length=1536 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2-VL-2B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ + data.trust_remote_code=True \ + data.custom_cls.path=recipe/minicpmo/rl_dataset.py \ + data.custom_cls.name=RLHFDataset \ + actor_rollout_ref.model.path=openbmb/MiniCPM-o-2_6 \ + actor_rollout_ref.model.trust_remote_code=True \ actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.use_dynamic_bsz=False \ actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + +actor_rollout_ref.actor.fsdp_config.use_orig_params=True \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.adv_estimator=grpo \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_example_geo3k' \ - trainer.experiment_name='qwen2vl_e2e_ci_function_rm' \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='minicpmo2_6_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_moonlight16b_math_megatron.sh b/examples/grpo_trainer/run_moonlight16b_math_megatron.sh similarity index 98% rename from tests/e2e/run_moonlight16b_math_megatron.sh rename to examples/grpo_trainer/run_moonlight16b_math_megatron.sh index b67833809..aebac5c18 100644 --- a/tests/e2e/run_moonlight16b_math_megatron.sh +++ b/examples/grpo_trainer/run_moonlight16b_math_megatron.sh @@ -43,7 +43,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.dist_checkpointing_path=/mnt/hdfs/gaoziyuan/dist_ckpt/moonshotai/Moonlight-16B-A3B \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='moonlight_megatron_ep' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen_grpo.sh b/examples/grpo_trainer/run_qwen2-7b.sh similarity index 87% rename from tests/e2e/run_qwen_grpo.sh rename to examples/grpo_trainer/run_qwen2-7b.sh index f9162896f..c32087e8c 100644 --- a/tests/e2e/run_qwen_grpo.sh +++ b/examples/grpo_trainer/run_qwen2-7b.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ @@ -33,12 +31,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=2 $@ \ No newline at end of file + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/e2e/run_qwen2-7b_math.sh b/examples/grpo_trainer/run_qwen2-7b_math.sh similarity index 90% rename from tests/e2e/run_qwen2-7b_math.sh rename to examples/grpo_trainer/run_qwen2-7b_math.sh index 09c723ea4..f4e6ec408 100644 --- a/tests/e2e/run_qwen2-7b_math.sh +++ b/examples/grpo_trainer/run_qwen2-7b_math.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/gsm8k/train.parquet gsm8k_test_path=$HOME/data/gsm8k/test.parquet @@ -41,11 +39,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_deepseek_grpo_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh similarity index 50% rename from tests/e2e/run_deepseek_grpo_megatron.sh rename to examples/grpo_trainer/run_qwen2-7b_math_megatron.sh index d4eb929a7..0a23bab8f 100644 --- a/tests/e2e/run_deepseek_grpo_megatron.sh +++ b/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh @@ -1,42 +1,62 @@ set -x -export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +rollout_mode="sync" +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +USE_FUSED_KERNELS=True python3 -m verl.trainer.main_ppo --config-path=config \ --config-name='ppo_megatron_trainer.yaml'\ algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=$return_raw_chat \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - algorithm.kl_ctrl.kl_coef=0.001 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm_math_megatron' \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ - trainer.save_freq=-1 \ + trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=2 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_qwen_grpo_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh similarity index 62% rename from tests/e2e/run_qwen_grpo_megatron.sh rename to examples/grpo_trainer/run_qwen2-7b_seq_balance.sh index 1a1c27d75..79881a1e0 100644 --- a/tests/e2e/run_qwen_grpo_megatron.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -1,17 +1,15 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS # For async rollout mode, dataset should return raw chat. -rollout_mode="sync" +rollout_mode="async" +rollout_name="sglang" # sglang or vllm if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 return_raw_chat="True" - chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler fi -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ +python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ @@ -23,27 +21,27 @@ python3 -m verl.trainer.main_ppo --config-path=config \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.name=$rollout_name \ actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \ trainer.val_before_train=False \ @@ -51,5 +49,4 @@ python3 -m verl.trainer.main_ppo --config-path=config \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=2 $@ \ No newline at end of file + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/e2e/run_qwen2-7b_seq_balance_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh similarity index 91% rename from tests/e2e/run_qwen2-7b_seq_balance_math_megatron.sh rename to examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh index 07f5319d6..54572c02d 100644 --- a/tests/e2e/run_qwen2-7b_seq_balance_math_megatron.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -44,11 +42,11 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='qwen2_7b_megatron' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_qwen2-7b_sgl_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh similarity index 97% rename from tests/e2e/run_qwen2-7b_sgl_megatron.sh rename to examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh index df66fdc3f..eeac388b4 100644 --- a/tests/e2e/run_qwen2-7b_sgl_megatron.sh +++ b/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh @@ -38,7 +38,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm_megatron' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen2_5-3b_gsm8k_grpo_lora.sh b/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh similarity index 90% rename from tests/e2e/run_qwen2_5-3b_gsm8k_grpo_lora.sh rename to examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh index eafac3d42..81236f621 100644 --- a/tests/e2e/run_qwen2_5-3b_gsm8k_grpo_lora.sh +++ b/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ @@ -39,11 +37,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2.5_3b_grpo_lora' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/e2e/run_qwen2_5-7b_math_megatron_diff_tp.sh b/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh similarity index 90% rename from tests/e2e/run_qwen2_5-7b_math_megatron_diff_tp.sh rename to examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh index 185da0c6b..d4a1a3fcd 100644 --- a/tests/e2e/run_qwen2_5-7b_math_megatron_diff_tp.sh +++ b/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -43,11 +41,11 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='qwen2_7b_megatron' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/tests/npu/run_qwen2_5_32b_grpo.sh b/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh similarity index 90% rename from tests/npu/run_qwen2_5_32b_grpo.sh rename to examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh index 461b27b80..6d0d4fe4e 100644 --- a/tests/npu/run_qwen2_5_32b_grpo.sh +++ b/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh @@ -1,8 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS - python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ @@ -33,7 +30,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_5_32b_function_rm' \ trainer.n_gpus_per_node=16 \ diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh new file mode 100644 index 000000000..44e94cd07 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh @@ -0,0 +1,72 @@ +set -x + +# profiling configuration +PROFILE_STEPS="[2,4]" +PROFILE_RANKS_ALL=False +DISCRETE=True +PROFILE_RANKS="[1,2]" + +# profiling NPU options +SAVE_PATH="$HOME/profile_data" +LEVEL="level1" +WITH_MEMORY=False +RECORD_SHAPES=False +WITH_NPU=True +WITH_CPU=True +WITH_MODULE=False +WITH_STACK=False +ANALYSIS=True + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ + actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.profiler.discrete=$DISCRETE \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.npu_profile.options.save_path=$SAVE_PATH \ + trainer.npu_profile.options.level=$LEVEL \ + trainer.npu_profile.options.with_memory=$WITH_MEMORY \ + trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \ + trainer.npu_profile.options.with_npu=$WITH_NPU \ + trainer.npu_profile.options.with_cpu=$WITH_CPU \ + trainer.npu_profile.options.with_module=$WITH_MODULE \ + trainer.npu_profile.options.with_stack=$WITH_STACK \ + trainer.npu_profile.options.analysis=$ANALYSIS \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.profile_steps=$PROFILE_STEPS \ + trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh new file mode 100644 index 000000000..70491c235 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh @@ -0,0 +1,70 @@ +set -x + +# profiling configuration +PROFILE_STEPS="[2,4]" +PROFILE_RANKS_ALL=True +DISCRETE=False + +# profiling NPU options +SAVE_PATH="$HOME/profile_data" +LEVEL="level1" +WITH_MEMORY=False +RECORD_SHAPES=False +WITH_NPU=True +WITH_CPU=True +WITH_MODULE=False +WITH_STACK=False +ANALYSIS=True + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.profiler.discrete=$DISCRETE \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.npu_profile.options.save_path=$SAVE_PATH \ + trainer.npu_profile.options.level=$LEVEL \ + trainer.npu_profile.options.with_memory=$WITH_MEMORY \ + trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \ + trainer.npu_profile.options.with_npu=$WITH_NPU \ + trainer.npu_profile.options.with_cpu=$WITH_CPU \ + trainer.npu_profile.options.with_module=$WITH_MODULE \ + trainer.npu_profile.options.with_stack=$WITH_STACK \ + trainer.npu_profile.options.analysis=$ANALYSIS \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.profile_steps=$PROFILE_STEPS \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_7b_grpo.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh similarity index 90% rename from tests/npu/run_qwen2_5_7b_grpo.sh rename to examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh index ff173e2b5..07dda340c 100644 --- a/tests/npu/run_qwen2_5_7b_grpo.sh +++ b/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh @@ -1,8 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS - python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ @@ -34,7 +31,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_5_7b_function_rm' \ trainer.n_gpus_per_node=16 \ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh new file mode 100644 index 000000000..d0de1aac5 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh @@ -0,0 +1,88 @@ +set -x +ENGINE=${1:-vllm} +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct +DIST_CKPT_PATH=${DIST_CKPT_PATH} + +# convert HF model to meagatron format offlinely +# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH + + +# megatron tuning guide: +# 1. recommend to offload all states by setting ALL_OFFLOAD=True +# 2. enable dynamic batch size by setting actor_rollout_ref.actor.use_dynamic_bsz=True ref.log_prob_use_dynamic_bsz=True rollout.log_prob_use_dynamic_bsz=True +# 3. set ppo_max_token_len_per_gpu and log_prob_max_token_len_per_gpu as large as possible for better MFU (limited by GPU memory). assure ppo_max_token_len_per_gpu > max_prompt_length+max_response_length, if sequence length is too long, you can increase the TP/PP size +# 4. if memory is very limited, enable full recompute, but the mfu will be 30% lower +# full recompute settings: +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} +COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} +COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} +COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} + +ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} + + +train_path=$HOME/data/geo3k/train.parquet +test_path=$HOME/data/geo3k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ + actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/tests/e2e/run_qwen2_5_vl-7b.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh similarity index 87% rename from tests/e2e/run_qwen2_5_vl-7b.sh rename to examples/grpo_trainer/run_qwen2_5_vl-7b.sh index 947314d7c..450390e25 100644 --- a/tests/e2e/run_qwen2_5_vl-7b.sh +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh @@ -1,7 +1,5 @@ set -x ENGINE=${1:-vllm} -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ @@ -28,20 +26,21 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.enable_chunked_prefill=False \ actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_geo3k' \ trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh new file mode 100644 index 000000000..b00ad8087 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=3e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.model.exclude_modules='.*visual.*' \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh new file mode 100644 index 000000000..e9933b106 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh @@ -0,0 +1,45 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=6144 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=6144 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh new file mode 100644 index 000000000..ef1301126 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-32B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_32b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh new file mode 100644 index 000000000..b319dee99 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh new file mode 100644 index 000000000..913da5424 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/e2e/run_qwen3-236b_megatron.sh b/examples/grpo_trainer/run_qwen3-236b_megatron.sh similarity index 94% rename from tests/e2e/run_qwen3-236b_megatron.sh rename to examples/grpo_trainer/run_qwen3-236b_megatron.sh index dd064b8d6..7c3f741db 100644 --- a/tests/e2e/run_qwen3-236b_megatron.sh +++ b/examples/grpo_trainer/run_qwen3-236b_megatron.sh @@ -1,6 +1,10 @@ #!/usr/bin/env bash set -xeuo pipefail +# Note that we set the response length to 4k. This results in many truncations at the beginning. +# So the training dynamic acts as using RL to compress the math capabilities of QWen3 236b into 4k response instead of verbose thinking. +# We can achieve 0.5 on AIME'24 after 30 steps. + project_name='DAPO' exp_name='DAPO-Qwen3-236b-megatron-0531a1' @@ -120,7 +124,7 @@ python3 -m verl.trainer.main_ppo \ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen3-8b.sh b/examples/grpo_trainer/run_qwen3-8b.sh similarity index 97% rename from tests/e2e/run_qwen3-8b.sh rename to examples/grpo_trainer/run_qwen3-8b.sh index 4acacb2a9..a99b432d6 100644 --- a/tests/e2e/run_qwen3-8b.sh +++ b/examples/grpo_trainer/run_qwen3-8b.sh @@ -33,7 +33,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen3_8b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen3moe-30b_megatron.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh similarity index 91% rename from tests/e2e/run_qwen3moe-30b_megatron.sh rename to examples/grpo_trainer/run_qwen3moe-30b_megatron.sh index b33aab142..49d5eb999 100644 --- a/tests/e2e/run_qwen3moe-30b_megatron.sh +++ b/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh @@ -5,8 +5,6 @@ DIST_CKPT_PATH=${DIST_CKPT_PATH} python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping python3 -m verl.trainer.main_ppo --config-path=config \ @@ -46,11 +44,11 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k_math' \ trainer.experiment_name='qwen3_30b_moe_megatron' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/README.md b/examples/ppo_trainer/README.md index 7d4069414..f4df70f9a 100644 --- a/examples/ppo_trainer/README.md +++ b/examples/ppo_trainer/README.md @@ -37,9 +37,9 @@ Most critic configs are similar to those of actors. Note that the critic model i - `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic +- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs` -- `algorithm.gemma`: discount factor +- `algorithm.gamma`: discount factor - `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator @@ -86,7 +86,7 @@ Qwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/ver bash run_gemma.sh trainer.n_gpus_per_node=1 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - trainer.logger=['console'] \ + trainer.logger=console \ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ data.train_batch_size=256 \ diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh index 26a510c32..01e4a24a1 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm.sh @@ -31,7 +31,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_deepseek7b_llm_modelscope.sh b/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh similarity index 97% rename from tests/e2e/run_deepseek7b_llm_modelscope.sh rename to examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh index 85d77caec..eb6dc7923 100644 --- a/tests/e2e/run_deepseek7b_llm_modelscope.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh @@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_deepseek7b_llm_pfppo.sh b/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh similarity index 97% rename from tests/e2e/run_deepseek7b_llm_pfppo.sh rename to examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh index 336884cef..312c6b50b 100644 --- a/tests/e2e/run_deepseek7b_llm_pfppo.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh @@ -35,7 +35,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_deepseek7b_llm_sandbox_fusion.sh b/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh similarity index 97% rename from tests/e2e/run_deepseek7b_llm_sandbox_fusion.sh rename to examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh index 46488a23d..69ee7b8bd 100644 --- a/tests/e2e/run_deepseek7b_llm_sandbox_fusion.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh @@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_sandbox_fusion' \ trainer.experiment_name='deepseek_llm_7b_function_sandbox_fusion' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh index cee6fee29..3cb8a852b 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh @@ -33,7 +33,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh index a2147cc57..976641f13 100644 --- a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh +++ b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh @@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat reward_model.param_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_megatron_full_hh_rlhf_examples' \ trainer.experiment_name='deepseek_llm_7b_model_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh index a33367b46..c747b573f 100644 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh @@ -2,8 +2,6 @@ set -x # Example runnable on H20 * 8 -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -42,7 +40,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_ppo_gsm8k_math_examples' \ trainer.experiment_name='deepseek_llm_7b_megatron' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh new file mode 100644 index 000000000..9cbbade33 --- /dev/null +++ b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh @@ -0,0 +1,64 @@ +set -x + +# Example runnable on H20 * 8 + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files=${train_files:-"$gsm8k_train_path"} +test_files=${test_files:-"$gsm8k_test_path"} + +# Nsight profiling configuration +PROFILE_STEPS="[1,2,5]" # or [] or null +PROFILE_RANKS_ALL=False # or True +PROFILE_RANKS=[0,4,8,12] +DISCRETE=True # or True + +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ + algorithm.adv_estimator=gae \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ + actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.profiler.discrete=$DISCRETE \ + critic.optim.lr=1e-5 \ + critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=4 \ + critic.profiler.ranks=$PROFILE_RANKS \ + critic.profiler.all_ranks=$PROFILE_RANKS_ALL \ + critic.profiler.discrete=$DISCRETE \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_ppo_gsm8k_math_examples' \ + trainer.experiment_name='deepseek_llm_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=100 \ + trainer.total_training_steps=6 \ + trainer.profile_steps=$PROFILE_STEPS $@ diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh index 0b0b83df6..b015275c1 100644 --- a/examples/ppo_trainer/run_gemma.sh +++ b/examples/ppo_trainer/run_gemma.sh @@ -30,7 +30,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='gemma2b_function_rm' \ trainer.n_gpus_per_node=2 \ diff --git a/tests/e2e/run_moonlight16b_a3b_gsm8k_megatron.sh b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh similarity index 96% rename from tests/e2e/run_moonlight16b_a3b_gsm8k_megatron.sh rename to examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh index c9ab3bde8..64bdbb727 100644 --- a/tests/e2e/run_moonlight16b_a3b_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping @@ -68,7 +66,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_megatron_gsm8k_examples' \ trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh similarity index 94% rename from tests/e2e/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh rename to examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh index 409e099ce..accdd7f65 100644 --- a/tests/e2e/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping # 0. download the model @@ -65,7 +63,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_megatron_gsm8k_examples' \ trainer.experiment_name='qwen1.5_moe_nochat' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh index 9a30fa80b..22558c62b 100644 --- a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping gsm8k_train_path=$HOME/data/gsm8k/train.parquet @@ -40,7 +38,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_ppo_gsm8k_math_examples' \ trainer.experiment_name='qwen2_7b_megatron' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh index 4ca5517cb..98b305844 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm.sh @@ -15,7 +15,6 @@ math_test_path=$HOME/data/math/test.parquet train_files="['$gsm8k_train_path', '$math_train_path']" test_files="['$gsm8k_test_path', '$math_test_path']" -export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2-7b with flash_attn has some issues # prepare model ckpt huggingface-cli download Qwen/Qwen2-7B-Instruct --local-dir $HOME/models/Qwen2-7B-Instruct & @@ -61,7 +60,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.micro_batch_size_per_gpu=32 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.val_before_train=False \ trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh index cf243c00f..e0ddc01e7 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh @@ -49,7 +49,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.forward_max_token_len_per_gpu=98304 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh new file mode 100644 index 000000000..7e0a335ef --- /dev/null +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh @@ -0,0 +1,64 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +FUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=4096 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ + critic.optim.lr=1e-5 \ + critic.model.use_remove_padding=True \ + critic.model.path=Qwen/Qwen2-7B-Instruct \ + critic.model.enable_gradient_checkpointing=True \ + critic.use_dynamic_bsz=True \ + critic.ppo_max_token_len_per_gpu=98304 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + reward_model.enable=True \ + reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ + reward_model.model.use_remove_padding=True \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=32 \ + reward_model.use_dynamic_bsz=True \ + reward_model.forward_max_token_len_per_gpu=98304 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=False \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh new file mode 100644 index 000000000..4173d02ea --- /dev/null +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh @@ -0,0 +1,78 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files=${train_files:-"$gsm8k_train_path"} +test_files=${test_files:-"$gsm8k_test_path"} + +PROFILE_STEPS="[1,2,5]" # or [] or null +PROFILE_RANKS_ALL=False # or True +PROFILE_RANKS=[0,4,8,12] +DISCRETE=True # or True + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=4096 \ + data.max_prompt_length=4096 \ + data.max_response_length=4096 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ + actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ + actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.profiler.discrete=$DISCRETE \ + critic.optim.lr=1e-5 \ + critic.model.use_remove_padding=True \ + critic.model.path=Qwen/Qwen2-7B-Instruct \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size_per_gpu=2 \ + critic.use_dynamic_bsz=True \ + critic.ppo_max_token_len_per_gpu=98304 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + critic.profiler.ranks=$PROFILE_RANKS \ + critic.profiler.all_ranks=$PROFILE_RANKS_ALL \ + critic.profiler.discrete=$DISCRETE \ + reward_model.enable=True \ + reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ + reward_model.model.use_remove_padding=True \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=32 \ + reward_model.use_dynamic_bsz=True \ + reward_model.forward_max_token_len_per_gpu=98304 \ + reward_model.profiler.ranks=$PROFILE_RANKS \ + reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \ + reward_model.profiler.discrete=$DISCRETE \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=False \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 \ + trainer.total_training_steps=6 \ + trainer.profile_steps=$PROFILE_STEPS $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh index abc490acd..9717e5f94 100644 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh @@ -12,7 +12,6 @@ test_files="['$gsm8k_test_path', '$math_test_path']" rollout_mode="sync" if [ "$rollout_mode" = "async" ]; then return_raw_chat="True" - chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler fi python3 -m verl.trainer.main_ppo \ @@ -38,7 +37,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ critic.optim.lr=1e-5 \ @@ -50,7 +49,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \ trainer.n_gpus_per_node=8 \ diff --git a/tests/e2e/run_qwen2-7b_sglang_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh similarity index 98% rename from tests/e2e/run_qwen2-7b_sglang_seq_balance.sh rename to examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh index 42b2bd9e9..5108e8b5d 100644 --- a/tests/e2e/run_qwen2-7b_sglang_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh @@ -40,7 +40,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh index 7c0b29d17..580376585 100644 --- a/examples/ppo_trainer/run_qwen2.5-32b.sh +++ b/examples/ppo_trainer/run_qwen2.5-32b.sh @@ -40,7 +40,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ray/tutorial.ipynb b/examples/ray/tutorial.ipynb index 17686b1df..ca176af0f 100644 --- a/examples/ray/tutorial.ipynb +++ b/examples/ray/tutorial.ipynb @@ -919,7 +919,9 @@ } ], "source": [ - "output = layer_worker_group.run_layer([x]) # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n", + "output = layer_worker_group.run_layer(\n", + " [x]\n", + ") # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n", "print(output[0].shape)" ] }, diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh index 98848c6ef..c2bf6d05b 100644 --- a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh +++ b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/gsm8k/train.parquet gsm8k_test_path=$HOME/data/gsm8k/test.parquet @@ -41,11 +39,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=16 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh index afad18b3c..b134ee5d1 100644 --- a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh +++ b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/gsm8k/train.parquet gsm8k_test_path=$HOME/data/gsm8k/test.parquet @@ -41,11 +39,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=16 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh index 48eb06e70..feebe8a84 100644 --- a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh +++ b/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh @@ -3,8 +3,6 @@ set -x export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=remax \ @@ -34,7 +32,7 @@ python3 -m verl.trainer.main_ppo \ algorithm.kl_penalty=kl \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_remax_example_gsm8k' \ trainer.experiment_name='qwen2.5_3b_function_rm_kl1e-3' \ trainer.val_before_train=False \ diff --git a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh index 342eb8a67..8734eb351 100644 --- a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh +++ b/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh @@ -3,8 +3,6 @@ set -x export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=remax \ @@ -34,7 +32,7 @@ python3 -m verl.trainer.main_ppo \ algorithm.kl_penalty=kl \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_remax_example_gsm8k' \ trainer.experiment_name='qwen2.5_7b_function_rm_kl1e-3' \ trainer.val_before_train=False \ diff --git a/examples/rloo_trainer/run_qwen2-7b.sh b/examples/rloo_trainer/run_qwen2-7b.sh index ec8fceeda..fc9b6e29f 100644 --- a/examples/rloo_trainer/run_qwen2-7b.sh +++ b/examples/rloo_trainer/run_qwen2-7b.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=rloo \ @@ -32,11 +30,11 @@ python3 -m verl.trainer.main_ppo \ algorithm.kl_penalty=kl \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_rloo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh index 4aa33dbfd..8a067f05d 100644 --- a/examples/sft/gsm8k/run_deepseek_6b7.sh +++ b/examples/sft/gsm8k/run_deepseek_6b7.sh @@ -25,5 +25,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','wandb'] \ - trainer.default_hdfs_dir=null $@ \ No newline at end of file + trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh index ae3e537b6..5b59893d2 100644 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ b/examples/sft/gsm8k/run_gemma_2b.sh @@ -27,5 +27,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-2b-it \ trainer.total_epochs=2 \ - trainer.logger=['console','wandb'] \ - trainer.default_hdfs_dir=null $@ \ No newline at end of file + trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh new file mode 100644 index 000000000..aee7603d7 --- /dev/null +++ b/examples/sft/gsm8k/run_gemma_7b.sh @@ -0,0 +1,26 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_gemma_7b.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=prompt \ + data.response_key=answer \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=google/gemma-1.1-7b-it \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ + trainer.total_epochs=4 \ + trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh new file mode 100644 index 000000000..45e427f39 --- /dev/null +++ b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh @@ -0,0 +1,35 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen2_5_05b_sft_peft_sp2_npu.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=64 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ + trainer.logger=console \ + trainer.total_epochs=2 $@ \ + model.lora_rank=32 \ + model.lora_alpha=16 \ + model.target_modules=all-linear \ + model.strategy=fsdp \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh index 37711d857..3a7d44558 100644 --- a/examples/sft/gsm8k/run_qwen_05_peft.sh +++ b/examples/sft/gsm8k/run_qwen_05_peft.sh @@ -27,9 +27,8 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ - trainer.logger=['console'] \ - trainer.total_epochs=1 \ - trainer.default_hdfs_dir=null $@ \ + trainer.logger=console \ + trainer.total_epochs=1 $@ \ model.lora_rank=32\ model.lora_alpha=16 \ model.target_modules=all-linear diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh index e3fac993f..7210a5a40 100644 --- a/examples/sft/gsm8k/run_qwen_05_sp2.sh +++ b/examples/sft/gsm8k/run_qwen_05_sp2.sh @@ -25,8 +25,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \ - trainer.logger=['console'] \ - trainer.total_training_steps=1 \ - trainer.default_hdfs_dir=null $@ \ + trainer.logger=console \ + trainer.total_training_steps=1 $@ \ ulysses_sequence_parallel_size=2 \ use_remove_padding=true diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh index 280e866c9..1c5cd591f 100644 --- a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh +++ b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh @@ -26,7 +26,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \ - trainer.logger=['console'] \ - trainer.default_hdfs_dir=null $@ \ + trainer.logger=console $@ \ ulysses_sequence_parallel_size=2 \ use_remove_padding=true diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh index 1da72070b..5e1fc47e9 100644 --- a/examples/sft/multiturn/run_qwen_05_sp2.sh +++ b/examples/sft/multiturn/run_qwen_05_sp2.sh @@ -12,7 +12,7 @@ save_path=$2 # Shift the arguments so $@ refers to the rest shift 2 -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ +torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \ -m verl.trainer.fsdp_sft_trainer \ data.train_files=$HOME/data/multiturn/train.parquet \ data.val_files=$HOME/data/multiturn/test.parquet \ @@ -23,8 +23,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.default_local_dir=$save_path \ trainer.project_name=multiturn-sft \ trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \ - trainer.logger=['console'] \ - trainer.total_training_steps=1 \ - trainer.default_hdfs_dir=null $@ \ + trainer.logger=console \ + trainer.total_training_steps=1 $@ \ ulysses_sequence_parallel_size=2 \ use_remove_padding=true \ No newline at end of file diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml new file mode 100644 index 000000000..a9523f196 --- /dev/null +++ b/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 256 + return_raw_chat: True + return_multi_modal_inputs: False + +actor_rollout_ref: + hybrid_engine: True + model: + custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml new file mode 100644 index 000000000..5e208f333 --- /dev/null +++ b/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 256 + return_raw_chat: True + return_multi_modal_inputs: False + +actor_rollout_ref: + hybrid_engine: True + model: + custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml index db133f8af..e9109232a 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml @@ -18,6 +18,4 @@ actor_rollout_ref: name: sglang multi_turn: enable: True - max_turns: 5 - format: qwen - # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" + max_assistant_turns: 5 diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml new file mode 100644 index 000000000..122f7e50f --- /dev/null +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml @@ -0,0 +1,21 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_user_turns: 5 diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml index 8609d8901..8aff859cc 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml @@ -18,7 +18,5 @@ actor_rollout_ref: name: sglang multi_turn: enable: True - max_turns: 5 - format: qwen - # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" - \ No newline at end of file + max_assistant_turns: 5 + diff --git a/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml b/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml new file mode 100644 index 000000000..78faf386e --- /dev/null +++ b/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml @@ -0,0 +1,4 @@ +interaction: + - name: "gsm8k" + class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + config: {} \ No newline at end of file diff --git a/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml b/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml index df6677dc4..d1cfaccce 100644 --- a/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml @@ -15,8 +15,8 @@ data: actor_rollout_ref: hybrid_engine: True rollout: - name: sglang_async + name: sglang multi_turn: enable: True - max_turns: 5 + max_assistant_turns: 5 tool_config_path: "./config/tool_config/sandbox_fusion_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/search_multiturn_grpo.yaml b/examples/sglang_multiturn/config/search_multiturn_grpo.yaml index 0c18ecf51..0e24f62b7 100644 --- a/examples/sglang_multiturn/config/search_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/search_multiturn_grpo.yaml @@ -16,8 +16,8 @@ data: actor_rollout_ref: hybrid_engine: True rollout: - name: sglang_async + name: sglang multi_turn: enable: True - max_turns: 2 + max_assistant_turns: 2 format: qwen diff --git a/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml new file mode 100644 index 000000000..675a342e6 --- /dev/null +++ b/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml @@ -0,0 +1,16 @@ +tools: + - class_name: "verl.tools.geo3k_tool.Geo3kTool" + config: + type: native + tool_schema: + type: "function" + function: + name: "calc_geo3k_reward" + description: "A tool for calculating the reward of geo3k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" + parameters: + type: "object" + properties: + answer: + type: "string" + description: "The model's answer to the geo3k problem, must be a digits" + required: ["answer"] \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml index 4caedc1da..a4197baab 100644 --- a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml +++ b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml @@ -1,6 +1,7 @@ tools: - class_name: "verl.tools.gsm8k_tool.Gsm8kTool" - config: {} + config: + type: native tool_schema: type: "function" function: diff --git a/examples/sglang_multiturn/config/tool_config/mcp_server.json b/examples/sglang_multiturn/config/tool_config/mcp_server.json new file mode 100644 index 000000000..29424f71e --- /dev/null +++ b/examples/sglang_multiturn/config/tool_config/mcp_server.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "Tavily Expert": { + "url": "your_tavily_expert_url", + "auth_token": "your_tavily_api_token" + } + } +} \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml new file mode 100644 index 000000000..40abf7c67 --- /dev/null +++ b/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml @@ -0,0 +1,11 @@ +tools: + - class_name: verl.tools.mcp_search_tool.MCPSearchTool + config: + rate_limit: 120 + timeout: 120 + type: mcp + mcp: + mcp_servers_config_path: ./mcp_server.json + # optional + tool_selected_list: + - tavily_search_tool \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml index 8ffb041dd..516acf569 100644 --- a/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml +++ b/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml @@ -7,6 +7,8 @@ tools: rate_limit: 10 default_timeout: 30 default_language: "python" + memory_limit_mb: 1024 + type: native tool_schema: type: "function" diff --git a/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml index 79b647e62..926b6b832 100644 --- a/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml +++ b/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml @@ -5,6 +5,7 @@ tools: num_workers: 120 rate_limit: 120 timeout: 30 + type: native tool_schema: type: function function: diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh new file mode 100644 index 000000000..d9306e9df --- /dev/null +++ b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh @@ -0,0 +1,54 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ + data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh new file mode 100644 index 000000000..66f12a5e5 --- /dev/null +++ b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh @@ -0,0 +1,58 @@ +# run on 4xH100 +# make sure your current working directory is the root of the project + +set -x +export HYDRA_FULL_ERROR=1 +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-async-sgl-multi-w-tool-verify-n16-4cards' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + trainer.total_epochs=15 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ + critic.ppo_max_token_len_per_gpu=8192 \ + critic.forward_max_token_len_per_gpu=8192 \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + $@ \ No newline at end of file diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh new file mode 100644 index 000000000..547b34d43 --- /dev/null +++ b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh @@ -0,0 +1,65 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project +# this is a verification training script, the parallel setting should be tuned to your model + +set -x + +export PYTHONUNBUFFERED=1 +export RAY_DEDUP_LOGS=0 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_megatron_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.context_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.megatron.seed=42 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.context_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ + data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh b/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh new file mode 100755 index 000000000..d67a76e48 --- /dev/null +++ b/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh @@ -0,0 +1,56 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.sampler.class_name="RandomCurriculumSampler" \ + data.sampler.class_path="pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu" \ + data.dataloader_num_workers=0 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.train_batch_size=256 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh b/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh new file mode 100644 index 000000000..2667664c9 --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh @@ -0,0 +1,58 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" +TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-512} +MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-8} +OFFLOAD=${OFFLOAD:-False} + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo_w_interaction' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=$TRAIN_BATCH_SIZE \ + data.max_prompt_length=1024 \ + data.max_response_length=$((1024 * 3)) \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + +actor_rollout_ref.model.enable_activation_offloading=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \ + +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ + actor_rollout_ref.ref.fsdp_config.param_offload=$OFFLOAD \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen2.5-0.5b_function_rm-gsm8k-sgl-multi-w-interaction-n8' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/train.parquet \ + data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/test.parquet \ + actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh index a3bcde50c..662723df4 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh @@ -39,7 +39,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='gsm8k_async_rl' \ trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ trainer.n_gpus_per_node=8 \ @@ -49,5 +49,6 @@ python3 -m verl.trainer.main_ppo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 $@ + trainer.total_epochs=15 \ + actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@ diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh index ee17a18b9..9e61893b0 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh @@ -39,7 +39,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='gsm8k_async_rl' \ trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \ trainer.n_gpus_per_node=4 \ @@ -55,4 +55,6 @@ python3 -m verl.trainer.main_ppo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=1 \ $@ \ No newline at end of file diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh new file mode 100644 index 000000000..11c104fa9 --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh @@ -0,0 +1,57 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.trace.backend=mlflow \ + actor_rollout_ref.rollout.trace.token2text=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","mlflow"]' \ + trainer.project_name='gsm8k_tool-agent' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + trainer.total_training_steps=2 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh index 671d58edd..a13d4f422 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh @@ -51,7 +51,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='gsm8k_async_rl' \ trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh new file mode 100755 index 000000000..56228f4b5 --- /dev/null +++ b/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh @@ -0,0 +1,53 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen3-4B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py index 46d4184a4..2f67c1439 100644 --- a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py +++ b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py @@ -18,7 +18,7 @@ import argparse import json import warnings -from typing import List, Optional +from typing import Optional import datasets import faiss @@ -75,7 +75,7 @@ def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16) self.model.eval() @torch.no_grad() - def encode(self, query_list: List[str], is_query=True) -> np.ndarray: + def encode(self, query_list: list[str], is_query=True) -> np.ndarray: # processing query for different encoders if isinstance(query_list, str): query_list = [query_list] @@ -88,19 +88,27 @@ def encode(self, query_list: List[str], is_query=True) -> np.ndarray: if "bge" in self.model_name.lower(): if is_query: - query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list] + query_list = [ + f"Represent this sentence for searching relevant passages: {query}" for query in query_list + ] - inputs = self.tokenizer(query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt") + inputs = self.tokenizer( + query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" + ) inputs = {k: v.cuda() for k, v in inputs.items()} if "T5" in type(self.model).__name__: # T5-based retrieval model - decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to(inputs["input_ids"].device) + decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( + inputs["input_ids"].device + ) output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) query_emb = output.last_hidden_state[:, 0, :] else: output = self.model(**inputs, return_dict=True) - query_emb = pooling(output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method) + query_emb = pooling( + output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method + ) if "dpr" not in self.model_name.lower(): query_emb = torch.nn.functional.normalize(query_emb, dim=-1) @@ -125,13 +133,13 @@ def __init__(self, config): def _search(self, query: str, num: int, return_score: bool): raise NotImplementedError - def _batch_search(self, query_list: List[str], num: int, return_score: bool): + def _batch_search(self, query_list: list[str], num: int, return_score: bool): raise NotImplementedError def search(self, query: str, num: int = None, return_score: bool = False): return self._search(query, num, return_score) - def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): return self._batch_search(query_list, num, return_score) @@ -166,7 +174,14 @@ def _search(self, query: str, num: int = None, return_score: bool = False): if self.contain_doc: all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] - results = [{"title": content.split("\n")[0].strip('"'), "text": "\n".join(content.split("\n")[1:]), "contents": content} for content in all_contents] + results = [ + { + "title": content.split("\n")[0].strip('"'), + "text": "\n".join(content.split("\n")[1:]), + "contents": content, + } + for content in all_contents + ] else: results = load_docs(self.corpus, [hit.docid for hit in hits]) @@ -175,7 +190,7 @@ def _search(self, query: str, num: int = None, return_score: bool = False): else: return results - def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): results = [] scores = [] for query in query_list: @@ -199,7 +214,13 @@ def __init__(self, config): self.index = faiss.index_cpu_to_all_gpus(self.index, co=co) self.corpus = load_corpus(self.corpus_path) - self.encoder = Encoder(model_name=self.retrieval_method, model_path=config.retrieval_model_path, pooling_method=config.retrieval_pooling_method, max_length=config.retrieval_query_max_length, use_fp16=config.retrieval_use_fp16) + self.encoder = Encoder( + model_name=self.retrieval_method, + model_path=config.retrieval_model_path, + pooling_method=config.retrieval_pooling_method, + max_length=config.retrieval_query_max_length, + use_fp16=config.retrieval_use_fp16, + ) self.topk = config.retrieval_topk self.batch_size = config.retrieval_batch_size @@ -216,7 +237,7 @@ def _search(self, query: str, num: int = None, return_score: bool = False): else: return results - def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): if isinstance(query_list, str): query_list = [query_list] if num is None: @@ -297,7 +318,7 @@ def __init__( class QueryRequest(BaseModel): - queries: List[str] + queries: list[str] topk: Optional[int] = None return_scores: bool = False @@ -334,7 +355,9 @@ def retrieve_endpoint(request: QueryRequest): request.topk = config.retrieval_topk # fallback to default # Perform batch retrieval - results, scores = retriever.batch_search(query_list=request.queries, num=request.topk, return_score=request.return_scores) + results, scores = retriever.batch_search( + query_list=request.queries, num=request.topk, return_score=request.return_scores + ) # Format response resp = [] @@ -342,7 +365,7 @@ def retrieve_endpoint(request: QueryRequest): if request.return_scores: # If scores are returned, combine them with results combined = [] - for doc, score in zip(single_result, scores[i]): + for doc, score in zip(single_result, scores[i], strict=True): combined.append({"document": doc, "score": score}) resp.append(combined) else: @@ -352,11 +375,20 @@ def retrieve_endpoint(request: QueryRequest): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Launch the local faiss retriever.") - parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.") - parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.") + parser.add_argument( + "--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file." + ) + parser.add_argument( + "--corpus_path", + type=str, + default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", + help="Local corpus file.", + ) parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.") parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.") - parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model.") + parser.add_argument( + "--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model." + ) parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation") args = parser.parse_args() diff --git a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh index 11becfce7..4415e47a9 100644 --- a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh +++ b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh @@ -43,16 +43,16 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.max_model_len=15000 \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.multi_turn.max_turns=2 \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.val_before_train=False \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='search_r1_like_async_rl' \ trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/slurm/ray_on_slurm.slurm b/examples/slurm/ray_on_slurm.slurm index f29e85a3a..86567d811 100644 --- a/examples/slurm/ray_on_slurm.slurm +++ b/examples/slurm/ray_on_slurm.slurm @@ -45,8 +45,6 @@ export ip_head echo "IP Head: $ip_head" # make sure we set environment variables before Ray initialization -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS printenv @@ -91,9 +89,8 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.use_kl_in_reward=False \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node="${SLURM_GPUS_PER_NODE}" \ trainer.nnodes="${SLURM_NNODES}" \ trainer.save_freq=10 \ diff --git a/examples/split_placement/README.md b/examples/split_placement/README.md new file mode 100644 index 000000000..226b5436d --- /dev/null +++ b/examples/split_placement/README.md @@ -0,0 +1,60 @@ +# Split Placement Example +Here we introduce how to run the naive implementation of the split placement of PPO algorithm. +We will release the complete version of flexible placement in the near future. + + For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example. + +### Step 1: Placing the models to different GPUs +Specify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs. +```python +actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' +critic_pool_id = 'critic_pool' +if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + } +else: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + } +print(f'resource_pool_spec: {resource_pool_spec}') +mapping = { + Role.ActorRollout: actor_rollout_ref_pool_id, + Role.Critic: critic_pool_id, + Role.RefPolicy: actor_rollout_ref_pool_id, +} +mapping[Role.RewardModel] = critic_pool_id +``` + +### Step 2: Make the models executed asynchronously +Based on the model placement, we need to make the models executed asynchronously. + +To do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations. +For example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py` + +``` +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_actor(self, data: DataProto): + ... +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_critic(self, data: DataProto): + ... +``` + +We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we don't do this in this example. + +### Step 3: Execute these operation in parallel in the single controller process +To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process. + +```python +critic_output = critic_output.get() +actor_output = actor_output.get() +``` + +### Step 4: Run the split placement example + +``` +bash run_deepseek7b_llm.sh +``` \ No newline at end of file diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index 35e1d3cfe..c438e7a13 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -123,8 +123,8 @@ def main_task(config): tokenizer = hf_tokenizer(local_path) # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker @@ -178,7 +178,7 @@ def main_task(config): # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == "fsdp": + if config.reward_model.strategy in {"fsdp", "fsdp2"}: from verl.workers.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker @@ -203,6 +203,7 @@ def main_task(config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit() diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh index d4484138a..473dcccdd 100644 --- a/examples/split_placement/run_deepseek7b_llm.sh +++ b/examples/split_placement/run_deepseek7b_llm.sh @@ -28,7 +28,7 @@ python3 main_ppo_split.py \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py index 81582af11..ef58509b9 100644 --- a/examples/split_placement/split_monkey_patch.py +++ b/examples/split_placement/split_monkey_patch.py @@ -25,11 +25,11 @@ from verl import DataProto from verl.trainer.ppo.ray_trainer import ( AdvantageEstimator, - _timer, apply_kl_penalty, compute_advantage, compute_data_metrics, compute_timing_metrics, + marked_timer, ) from verl.utils.metric import reduce_metrics @@ -81,13 +81,15 @@ def fit(self): gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) is_last_step = self.global_steps >= self.total_training_steps - with _timer("step", timing_raw): + with marked_timer("step", timing_raw): # generate a batch - with _timer("gen", timing_raw): + with marked_timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): + with marked_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) @@ -102,7 +104,9 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -118,23 +122,23 @@ def fit(self): batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # recompute old_log_probs - with _timer("old_log_prob", timing_raw): + with marked_timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer("ref", timing_raw): + with marked_timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer("values", timing_raw): + with marked_timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): + with marked_timer("adv", timing_raw): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. @@ -149,7 +153,9 @@ def fit(self): # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] @@ -165,37 +171,46 @@ def fit(self): norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, ) - # update critic - if self.use_critic: - with _timer("update_critic_call", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer("update_actor_call", timing_raw): + with marked_timer("update_actor_call", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) + else: + actor_output = None + + # update critic + if self.use_critic: + with marked_timer("update_critic_call", timing_raw): + critic_output = self.critic_wg.update_critic(batch) - # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class - with _timer("update_actor_critic", timing_raw): - critic_output = critic_output.get() - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) + # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class + with marked_timer("update_actor_critic", timing_raw): + critic_output = critic_output.get() + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + if actor_output is not None: actor_output = actor_output.get() actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer("testing", timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): - with _timer("save_checkpoint", timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics diff --git a/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 000000000..a40ae6f60 --- /dev/null +++ b/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=0.5b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct + +set -x +nproc_per_gpu=116 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 000000000..6b6ede29b --- /dev/null +++ b/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-1.5b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=1.5b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-1.5B-Instruct + +set -x +nproc_per_gpu=128 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh b/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh new file mode 100644 index 000000000..247945ffc --- /dev/null +++ b/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-14b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=14b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-14B-Instruct + +set -x +nproc_per_gpu=58 # 32√ → 64× → 48√ → 56√ → 60× → 58√ → 59× +nnodes=1 +ngpu_per_node=2 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.25 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh b/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh index 53f9b5887..2df21533c 100644 --- a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh +++ b/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh @@ -1,7 +1,5 @@ set -x -#export VLLM_ATTENTION_BACKEND=XFORMERS - gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet gsm8k_test_path=$HOME/data/rlhf/math/test.parquet model_path=Qwen/Qwen2.5-Coder-14B-Instruct @@ -39,11 +37,11 @@ PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_14b_function_rm' \ trainer.n_gpus_per_node=4 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=1 $@ \ No newline at end of file + trainer.total_epochs=1 $@ diff --git a/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh b/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh new file mode 100644 index 000000000..d707a4adc --- /dev/null +++ b/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-32b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=32b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-32B-Instruct + +set -x +nproc_per_gpu=45 # 32√ → 64× → 48× → 40√ → 44√ → 46× → 45× +nnodes=1 +ngpu_per_node=4 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh b/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh index a687dd37f..3a96fe504 100644 --- a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh +++ b/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh @@ -42,7 +42,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.ref.megatron.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='megatron_vllm_qwen2_32b' \ trainer.experiment_name='qwen2_32b_grpo_8_h20' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 000000000..fac34a5d5 --- /dev/null +++ b/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-3b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=3b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-3B-Instruct + +set -x +nproc_per_gpu=62 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh index 617e4aa15..9a1d50ad1 100644 --- a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh +++ b/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh @@ -33,7 +33,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='Qwen2_72B_Instruct' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh index dfb638c8c..b15f406b1 100644 --- a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh +++ b/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh @@ -35,7 +35,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='Qwen2_72B_Instruct' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh b/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh new file mode 100644 index 000000000..7f93ed32f --- /dev/null +++ b/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-72b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=72b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-72B-Instruct + +set -x +nproc_per_gpu=22 # 16√ → 32× → 24× → 20√ → 22√ → 23× +nnodes=1 +ngpu_per_node=8 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 000000000..a663a90d6 --- /dev/null +++ b/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-7b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=7b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-7B-Instruct + +set -x +nproc_per_gpu=16 # 64√ → 128× → 96√ → 112× → 104× → 100√ → 102× → 101× +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh index ae6d052f6..598e82b41 100644 --- a/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh +++ b/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh @@ -1,6 +1,5 @@ set -x -#export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet gsm8k_test_path=$HOME/data/rlhf/math/test.parquet @@ -39,11 +38,11 @@ PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=2 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/performance_tuning_guide.md b/performance_tuning_guide.md new file mode 100644 index 000000000..3629c0909 --- /dev/null +++ b/performance_tuning_guide.md @@ -0,0 +1,56 @@ +Optimizing training speed for large models like the 70B long-context model is a battle against two constraints: GPU memory and computational throughput. This guide explains the key parameters in veRL that allow you to manage these constraints effectively. + +## 1. Preliminaries + +Before diving into individual parameters, there is a crucial formula of global (or mini-)batch size to keep in mind. + +$$ +\text{mini\_batch\_size} = \text{micro\_batch\_size} \times \text{grad\_accum\_steps} \times \text{data\_parallel\_size} +$$ + +- **Mini-batch size**: The true batch size that the optimizer sees. Affects ML stability and convergence. +- **Micro-batch size**: What a single GPU processes in one forward/backward pass. Affects GPU memory usage and utilization. +- **Gradient accumulation**: How many micro-batch gradients to sum up before updating the model, usually to achieve a larger batch size. +- **Data parallel size (DP)**: The number of parallel groups working on different data replicas. + + +## 2. Training + +- **Mini-batch size(`actor.ppo_mini_batch_size`)**[tag: ML, system]: The mini (or global) batch size to update the model's gradients. This is the conventional mini (or global) batch size in classical ML and large-scale LLM (pre-)training. + - From an ML perspective, a larger batch size leads to more stable gradient estimates (lower variance), which is especially beneficial for LLM training. However, too large a mini-batch size can hurt model generalization by converging to sharp minima, slow down convergence in terms of number of parameter updates (fewer updates per epoch). When using online RL algorithms such as PPO, setting `actor.ppo_mini_batch_size` close to or even equal to `data.train_prompt_bsz` can improve “on-policyness,” thereby reducing policy lag and potentially leading to more stable and consistent updates. + - From a system perspective, we need to consider GPU memory limits and training speed. A larger mini-batch size, when achievable, generally leads to faster training by better utilizing the GPU's computational resources. This speedup is most obvious when no gradient accumulation is needed. Even with gradient accumulation, a larger mini-batch size can be faster because it amortizes the cost of the optimizer step over more data . +- **Micro-batch size(`actor.ppo_micro_batch_size_per_gpu`**and deprecated **`actor.ppo_micro_batch_size`)**[tag: system]: The actual batch size processed per GPU (or DP) in a single forward/backward pass. Unlike Mini-batch size, Micro-batch size is purely a system optimization parameter and does not directly affect ML performance. Its primary role is to manage memory usage while maximizing GPU utilization. Finding the optimal micro-batch size is tricky: it depends on the training sequence length, model sharding strategy (e.g., FSDP), and model parallelization scheme (e.g., Megatron). Also note, the `per_gpu` naming is a bit imprecise, it actually means `per_data_parallel` . E.g., if you enable sequence parallel, multiple GPUs in one DP handle one sample. + + *Best practice:* Start with a micro-batch size of 1. If there are no out-of-memory (OOM) errors, incrementally increase it to find the sweet spot that maximizes throughput. If you encounter OOM, consider more aggressive sharding or parallelism strategies like sequence parallelism before reducing the micro-batch size. Another option is to use `use_dynamic_bsz`. + +- **Dynamic batch size(`actor.use_dynamic_bsz`)**[tag: system]: If true, this setting packs multiple sequences into a single, long sequence up to a specified token limit and **ignores the specified micro-batch size**. Instead of tuning the micro-batch size, you tune `actor.ppo_max_token_len_per_gpu`. Under the hood, it uses sequence packing to arrange sequences into a single batch, minimizing padding tokens. This technique saves redundant GPU compute that would otherwise be wasted on padding tokens. + + *Best practice*: Enable it for workloads with variable sequence lengths. + +- **Sequence parallel(`actor.ulysses_sequence_parallel_size`** in FSDP or **`actor.megatron.context_parallel_size`** in Megatron**)**: The Transformer architecture's self-attention mechanism leads to activation memory scaling (near) quadratically with context length ($O(N^2)$). In long-context regimes (e.g., 32k sequence length), this activation memory becomes a primary cause of OOM errors. Sequence Parallelism (SP) addresses this by sharding the input sequence across GPUs within a data-parallel group. For example, setting `SP=4` on a 32k sequence would result in each of the 4 GPUs processing an 8k sub-sequence. The trade-off is that your effective data-parallel size is reduced by a factor of SP, which can slow down training due to increased communication for the attention mechanism's all-gather operations. + + *Best practice*: Set `SP > 1` when OOM errors are primarily caused by long context lengths. + +- **Gradient checkpointing(`enable_gradient_checkpointing`)**[tag: system]: During the backward pass, gradients are computed using activations from the forward pass, which are typically stored in GPU memory. Instead of storing all activations, gradient checkpointing saves only a subset and recomputes the others on-the-fly during the backward pass. This can drastically reduce memory consumption at the cost of computation time for the backward pass. + +- **Parameter offload(`param_offload`, `grad_offload`, `optimizer_offload`)**[tag: system]: This technique frees up GPU memory by moving model parameters or optimizer states to CPU. During computation, the necessary data is transferred back to the GPU. This can save a significant amount of GPU memory but introduces a substantial communication bottleneck between the CPU and GPU. + +**FSDP-specific** + +- **FSDP size(`fsdp_size=-1`)**[tag: system]: The `fsdp_size` controls the number of GPUs in each FSDP group. `fsdp_size=-1` (default): FSDP will shard the model parameters, gradients, and optimizer states **across all available training GPUs**. `fsdp_size > 1`: Specifies a custom number of GPUs per FSDP group. + +**Megatron-specific** + +- **Tensor parallel(`megatron.tensor_model_parallel_size`)**[tag: system]: Tensor Parallelism (TP) is an *intra-layer* parallelism technique that shards individual weight matrices within Transformer layers (e.g., in MLP and self-attention blocks) across multiple GPUs. This significantly reduces the memory footprint of the model parameters, grad optimizer states, and activation on each GPU. However, it requires high-bandwidth communication (all-reduce operations) within the TP group after each parallel computation. +- **Pipeline parallel(`megatron.pipeline_model_parallel_size`)**[tag: system]: Pipeline Parallelism (PP) is an *inter-layer* parallelism technique that partitions the model's layers into sequential stages, with each stage assigned to a different GPU. This reduces memory by requiring each GPU to store only a fraction of the model's layers and their activations. + +## 3. Rollout (vLLM) + +- **Generation batch size(`data.train_prompt_bsz`)**: This is the number of unique prompts sent to the vLLM engine for generation in a single batch. This is distinct from the training mini-batch size. A larger generation batch size can increase throughput by better utilizing the GPU and reducing the times of switch of rollout and train phases, but introduces the risk of “off-policy” if your training doesn’t support large enough mini-batch size (`actor.ppo_mini_batch_size`). +- **Tensor parallel(`rollout.tensor_model_parallel_size`)**: Similar to training, this shards the model's weights across multiple GPUs to serve models that are too large for a single GPU. For inference, TP is crucial for both fitting the model in memory and for increasing throughput. +- **KV cache(`rollout.gpu_memory_utilization`)**: The Key-Value (KV) cache is crucial for faster inference and the primary memory consumer during LLM inference. It stores the key and value states for all previously generated tokens in a sequence to prevent costly recomputation. `gpu_memory_utilization` in vLLM controls the fraction of GPU memory pre-allocated for the KV cache. A higher value (e.g., 0.90) allows vLLM to handle more concurrent requests and/or longer sequences efficiently, but leaves less memory for other processes. +- **Max token length per GPU(`infer_ppo_max_token_len`)**: Max tokens to be processed in the forward computation. Similar to `actor.ppo_max_token_len_per_gpu` in training. Increasing it will increase token throughput, but also increasing the GPU memory. + +## 4. Case Studies + +An example 70B long-cot training, using distilled-r1-Llama3.1-70B as example, is run `sbatch scripts/example_multinode_rl_llama3.1_70b_distill_megatron.sh` using 32 nodes with training TP=8, PP=2, SP=8, rollout TP=4. The parameters are tuned (e.g., we find PP doesn't need to be high, 2 is enough), but not optmized. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a97a81e22..e10da9e4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,9 @@ name = "verl" dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"] description = "verl: Volcano Engine Reinforcement Learning for LLM" -license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifier +license = {text = "Apache-2.0"} # Changed from file to text format readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.8" +requires-python = ">=3.10" # ------------------------------- # tool.ruff - Linting configuration @@ -29,7 +29,8 @@ requires-python = ">=3.8" [tool.ruff] # Note: While the formatter will attempt to format lines such that they remain within the line-length, # it isn't a hard upper bound, and formatted lines may exceed the line-length. -line-length = 300 # TODO: Reduce this to a more reasonable value +line-length = 120 +exclude = ["tests/workers/rollout/test_sglang_async_rollout_sf_tools.py", "scripts/legacy_model_merger.py"] [tool.ruff.lint] isort = {known-first-party = ["verl"]} @@ -56,10 +57,14 @@ ignore = [ "B007", # f-string format "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", # `.log()` statement uses f-string "G004", + # X | None for type annotations + "UP045", + # deprecated import + "UP035", + # line length + "E501" ] # ------------------------------- @@ -83,5 +88,6 @@ version = {file = "verl/version/version"} [tool.setuptools.package-data] verl = [ "version/*", - "trainer/config/*.yaml" -] \ No newline at end of file + "trainer/config/*.yaml", + "trainer/config/*/*.yaml", +] diff --git a/recipe/README.md b/recipe/README.md index e47f40e96..29fb40384 100644 --- a/recipe/README.md +++ b/recipe/README.md @@ -1,8 +1,12 @@ # Recipe The examples under `recipes/` are representative extensions to verl for specific end-to-end RL training recipes. +The help the community reproduce experiments, verl team provides a snapshot of the codebase when each recipe is initially PR'ed to verl main. You can find them via [github branches](https://github.com/volcengine/verl/branches/all?query=recipe) # Awesome work using verl +- [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset. ![GitHub Repo stars](https://img.shields.io/github/stars/Unakar/Logic-RL) +- [Seed-Coder](https://github.com/ByteDance-Seed/Seed-Coder): RL training of Seed-Coder boosts performance on competitive programming ![GitHub Repo stars](https://img.shields.io/github/stars/ByteDance-Seed/Seed-Coder) +- [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195) - [s3](https://github.com/pat-jj/s3) **Efficient Yet Effective** Search Agent Training via RL ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/s3) - [Rec-R1](https://arxiv.org/pdf/2503.24289): Bridging Generative Large Language Models and Recommendation Systems via Reinforcement Learning - [Explore RL Data Scaling](https://arxiv.org/abs/2503.22230): Exploring Data Scaling Trends and Effects in Reinforcement Learning from Human Feedback diff --git a/recipe/char_count/README.md b/recipe/char_count/README.md new file mode 100644 index 000000000..18f902d15 --- /dev/null +++ b/recipe/char_count/README.md @@ -0,0 +1,41 @@ +# Char Count +## Introduction +Char count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB. + +## Problem formulation +The prompt is: "How many {char} are there in {word}?". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example, + +```text +Question: How many n are there in n-i-n-e? +Answer: +n = n +i != n +n = n +e != n +\boxed{2} +``` + +Note that +- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer. +- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box. +- The task can be verified. +- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range. + +## Scripts +To create the dataset, run +```bash +python3 create_dataset.py +``` +We create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path. + +To run the SFT +```bash +bash train_sft.sh +``` +We train SFT for 3 epochs. After 3 epochs, the validation score is around 0.12. + +To run GRPO +```bash +bash train_grpo.sh +``` +We train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.36. diff --git a/recipe/char_count/create_dataset.py b/recipe/char_count/create_dataset.py new file mode 100644 index 000000000..47571e023 --- /dev/null +++ b/recipe/char_count/create_dataset.py @@ -0,0 +1,191 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Task description: +Given a random word and a random char, count the number of occurrence of char in the word. + +Create CoT dataset that split the word into separate char. Then list the char and count the occurrence. + +The word set comes from shakespeare +""" + +import os.path +import random + +prompt_template = "How many {} are there in word {}?" + + +def generate_random_char(): + return chr(97 + random.randint(0, 25)) + + +def create_prompt_response(min_length=3, max_length=5): + # randomly generate a length + word_length = random.randint(min_length, max_length) + # randomly generate a target count number. This makes the target number + target_count_number = random.randint(1, word_length) + + char_lst = [] + # generate the word + # step 1: generate the target word + target_char = generate_random_char() + + for _ in range(target_count_number): + char_lst.append(target_char) + + # step 2: generate other words + for _ in range(word_length - target_count_number): + while True: + char = generate_random_char() + if char != target_char: + char_lst.append(char) + break + + # step 3: random permute char_lst + random.shuffle(char_lst) + + word = "-".join(char_lst) + + prompt = prompt_template.format(target_char, word) + final_answer = [] + + # cot + number = 0 + for i, char in enumerate(char_lst): + cot = f"{char}" + if char != target_char: + cot += " != " + else: + cot += " = " + number += 1 + cot += f"{target_char}." + + final_answer.append(cot) + + conclusion = f"\\boxed{{{number}}} {target_char} in {word}." + + final_answer.append(conclusion) + + final_answer = "\n".join(final_answer) + + return prompt, final_answer + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--total_number", type=int, default=10000) + parser.add_argument("--min_length", type=int, default=5) + parser.add_argument("--max_length", type=int, default=20) + parser.add_argument("--data_path", type=str, default="~/data/char_count") + + args = vars(parser.parse_args()) + + total_number = args["total_number"] + min_length = args["min_length"] + max_length = args["max_length"] + data_path = args["data_path"] + data_path = os.path.expanduser(data_path) + + full_output = [] + for _ in range(total_number): + output = create_prompt_response(min_length=min_length, max_length=max_length) + full_output.append(output) + + # random reorder + random.shuffle(full_output) + + # split for train and test + train_split_len = int(0.9 * len(full_output)) + train_outputs = full_output[:train_split_len] + test_output = full_output[train_split_len:] + + sft_train_dataset = {"prompt": [], "response": []} + + for o in train_outputs: + sft_train_dataset["prompt"].append(o[0]) + sft_train_dataset["response"].append(o[1]) + + sft_test_dataset = {"prompt": [], "response": []} + + for o in test_output: + sft_test_dataset["prompt"].append(o[0]) + sft_test_dataset["response"].append(o[1]) + + import pandas as pd + + sft_train_dataset = pd.DataFrame(data=sft_train_dataset) + sft_test_dataset = pd.DataFrame(data=sft_test_dataset) + + folder = os.path.join(data_path, "sft") + + os.makedirs(folder, exist_ok=True) + + sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) + sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) + + # build RL dataset + rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} + + rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} + + from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed + + for o in train_outputs: + prompt = o[0] + response = o[1] + prompt_with_template = [ + { + "role": "user", + "content": prompt, + } + ] + + rl_train_dataset["prompt"].append(prompt_with_template) + rl_train_dataset["data_source"].append("char_count") + rl_train_dataset["ability"].append("other") + rl_train_dataset["reward_model"].append( + {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + ) + rl_train_dataset["extra_info"].append({"response": response}) + + for o in test_output: + prompt = o[0] + response = o[1] + prompt_with_template = [ + { + "role": "user", + "content": prompt, + } + ] + + rl_test_dataset["prompt"].append(prompt_with_template) + rl_test_dataset["data_source"].append("char_count") + rl_test_dataset["ability"].append("other") + rl_test_dataset["reward_model"].append( + {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + ) + rl_test_dataset["extra_info"].append({"response": response}) + + rl_train_dataset = pd.DataFrame(data=rl_train_dataset) + rl_test_dataset = pd.DataFrame(data=rl_test_dataset) + + folder = os.path.join(data_path, "rl") + + os.makedirs(folder, exist_ok=True) + + rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) + rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) diff --git a/recipe/char_count/reward_function.py b/recipe/char_count/reward_function.py new file mode 100644 index 000000000..9bdffe2a5 --- /dev/null +++ b/recipe/char_count/reward_function.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reward function +""" + +from verl.utils.reward_score import math + + +def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None): + try: + last_boxed_string = math.last_boxed_only_string(solution_str) + if last_boxed_string is None: + return 0 + solution = math.remove_boxed(last_boxed_string) + if solution == ground_truth: + return 1 + else: + return 0 + except Exception: + print(ground_truth, solution_str) + return 0 diff --git a/recipe/char_count/train_grpo.sh b/recipe/char_count/train_grpo.sh new file mode 100644 index 000000000..5de85422f --- /dev/null +++ b/recipe/char_count/train_grpo.sh @@ -0,0 +1,43 @@ +set -x + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/char_count/rl/train.parquet \ + data.val_files=$HOME/data/char_count/rl/test.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=128 \ + data.max_response_length=128 \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ + actor_rollout_ref.model.path=./models/sft/global_step_105 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","tensorboard"]' \ + trainer.project_name='verl_example' \ + trainer.experiment_name='smol135m_grpo' \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=2 \ + custom_reward_function.path=recipe/char_count/reward_function.py \ + custom_reward_function.name=char_count_reward_function diff --git a/recipe/char_count/train_sft.sh b/recipe/char_count/train_sft.sh new file mode 100644 index 000000000..56f5cec53 --- /dev/null +++ b/recipe/char_count/train_sft.sh @@ -0,0 +1,21 @@ +set -x + +nproc_per_node=1 +save_path=./models/sft + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/char_count/sft/train.parquet \ + data.val_files=$HOME/data/char_count/sft/test.parquet \ + data.prompt_key=prompt \ + data.response_key=response \ + data.micro_batch_size_per_gpu=8 \ + data.max_length=256 \ + data.train_batch_size=256 \ + use_remove_padding=True \ + model.partial_pretrain=HuggingFaceTB/SmolLM2-135M-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=char_count-sft \ + trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \ + trainer.total_epochs=3 \ + trainer.logger=console \ No newline at end of file diff --git a/recipe/dalu/config/dapo_fsdp_config.yaml b/recipe/dalu/config/dapo_fsdp_config.yaml new file mode 100644 index 000000000..47141447e --- /dev/null +++ b/recipe/dalu/config/dapo_fsdp_config.yaml @@ -0,0 +1,26 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit + diff --git a/recipe/dalu/config/dapo_megatron_config.yaml b/recipe/dalu/config/dapo_megatron_config.yaml new file mode 100644 index 000000000..5b83fab85 --- /dev/null +++ b/recipe/dalu/config/dapo_megatron_config.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit \ No newline at end of file diff --git a/recipe/dalu/config/dapo_trainer.yaml b/recipe/dalu/config/dapo_trainer.yaml new file mode 100644 index 000000000..47ac00fd6 --- /dev/null +++ b/recipe/dalu/config/dapo_trainer.yaml @@ -0,0 +1,28 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit + +trainer: + project_name: verl-dapo diff --git a/recipe/dalu/dalu_ray_trainer.py b/recipe/dalu/dalu_ray_trainer.py new file mode 100644 index 000000000..206ff069e --- /dev/null +++ b/recipe/dalu/dalu_ray_trainer.py @@ -0,0 +1,733 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.utils.profiler import marked_timer + + +class RayDALUTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize the drop count attributes to ensure they're always available + self.n_drop_easy = 0 + self.n_drop_hard = 0 + + def _create_priority_dataloader(self, epoch_idx, dynamic_filtering, enable_budget): + """ + Create the dataloader every time before the epoch starts. + """ + from torch.utils.data import SequentialSampler + from verl.trainer.main_ppo import create_rl_sampler + from verl.utils.dataset.rl_dataset import collate_fn + from torchdata.stateful_dataloader import StatefulDataLoader + + # Initialize columns for the first epoch + max_easy_ratio = self.config.data.get("max_easy_ratio", 0.1) + max_hard_ratio = self.config.data.get("max_hard_ratio", 0.2) + if epoch_idx == 0: + # Get the initial pass rate column name from config, with default fallback + initial_pass_rate_column = self.config.data.get("initial_pass_rate_column", "qwen3_30b_pass_rate") + self.train_dataset.dataframe["prev_pass_rate"] = self.train_dataset.dataframe[initial_pass_rate_column] + # use half of the max response length as the average length for the first epoch + self.train_dataset.dataframe["prev_passed_avg_length"] = self.config.data.get("max_response_length", 1024*28) * 3 / 4 + self.train_dataset.dataframe["prev_passed_max_length"] = self.config.data.get("max_response_length", 1024*28) * 3 / 4 + self.train_dataset.dataframe["prev_passed_80th_length"] = self.config.data.get("max_response_length", 1024*28) + self.train_dataset.dataframe["prev_passed_50th_length"] = self.config.data.get("max_response_length", 1024*28) + + def _assign_length_budget(row, pass_rate_upper_bound, max_response_length): + prompt_pass_rate = row['prev_pass_rate'] + passed_prompt_avg_length = row['prev_passed_avg_length'] + passed_prompt_max_length = row['prev_passed_max_length'] + + # Get configurable multipliers with default values + perfect_pass_rate_multiplier = self.config.data.get("perfect_pass_rate_multiplier", 1.0) + high_pass_rate_multiplier = self.config.data.get("high_pass_rate_multiplier", 0.8) + + if prompt_pass_rate == 1.0: + new_length_budget = max(high_pass_rate_multiplier * passed_prompt_max_length, passed_prompt_avg_length) + elif prompt_pass_rate > pass_rate_upper_bound: + new_length_budget = max(high_pass_rate_multiplier * passed_prompt_max_length, passed_prompt_avg_length) + else: + new_length_budget = passed_prompt_max_length + (max_response_length - passed_prompt_max_length) * (1 - prompt_pass_rate) + + new_length_budget = max(new_length_budget, 4000) # Set minimum to 2000 + new_length_budget = min(new_length_budget, max_response_length) # Cap at max response length + + # print(f"new_length_budget: {new_length_budget}") + # print(f"max_response_length: {max_response_length}") + # print(f"passed_prompt_max_length: {passed_prompt_max_length}") + # print(f"passed_prompt_avg_length: {passed_prompt_avg_length}") + # print(f"prompt_pass_rate: {prompt_pass_rate}") + # print(f"pass_rate_upper_bound: {pass_rate_upper_bound}") + + return int(new_length_budget) + + if enable_budget: + max_response_length = self.config.data.get("max_response_length", 1024*28) + pass_rate_upper_bound = self.config.trainer.get("pass_rate_upper_bound", 1.0) + + self.train_dataset.dataframe["per_prompt_length_budget"] = self.train_dataset.dataframe.apply( + lambda row: _assign_length_budget(row, pass_rate_upper_bound, max_response_length), + axis=1 + ) + else: + self.train_dataset.dataframe["per_prompt_length_budget"] = self.config.data.get("max_response_length", 1024*28) # Use fixed length budget + + if dynamic_filtering: + original_df = self.train_dataset.dataframe.copy() + # Separate data by pass rate + perfect_mask = original_df["prev_pass_rate"] == 1.0 + failed_mask = original_df["prev_pass_rate"] == 0.0 + medium_mask = (original_df["prev_pass_rate"] > 0.0) & (original_df["prev_pass_rate"] < 1.0) + + # Get indices for each category + medium_indices = original_df[medium_mask].index.tolist() + perfect_indices = original_df[perfect_mask].index.tolist() + failed_indices = original_df[failed_mask].index.tolist() + + # Keep all medium difficulty data + kept_indices = set(medium_indices) + n_medium = len(medium_indices) + + # Limit perfect examples to 1/10 of medium examples + self.n_drop_easy = 0 + if perfect_indices: + np.random.seed(42 + epoch_idx) + n_keep_perfect = int(max(1, min(n_medium * max_easy_ratio, len(perfect_indices)))) + if n_keep_perfect > 0: + kept_perfect = np.random.choice(perfect_indices, size=n_keep_perfect, replace=False) + kept_indices.update(kept_perfect) + self.n_drop_easy = len(perfect_indices) - n_keep_perfect + + # Limit failed examples to 1/5 of medium examples + self.n_drop_hard = 0 + if failed_indices: + np.random.seed(43 + epoch_idx) + n_keep_failed = int(max(1, min(n_medium * max_hard_ratio, len(failed_indices)))) + if n_keep_failed > 0: + kept_failed = np.random.choice(failed_indices, size=n_keep_failed, replace=False) + kept_indices.update(kept_failed) + self.n_drop_hard = len(failed_indices) - n_keep_failed + + filtered_df = original_df.loc[list(kept_indices)].reset_index(drop=True) + # Log filtering statistics + n_perfect_kept = len(set(perfect_indices) & kept_indices) + n_failed_kept = len(set(failed_indices) & kept_indices) + n_medium_kept = len(set(medium_indices) & kept_indices) + + print(f"Dataset filtering statistics for epoch {epoch_idx}:") + print(f"Original dataset size: {len(original_df)}") + print(f" - Perfect examples (pass_rate=1.0): {len(perfect_indices)} -> {n_perfect_kept} kept ({n_perfect_kept/max(1,len(perfect_indices))*100:.1f}%)") + print(f" - Failed examples (pass_rate=0.0): {len(failed_indices)} -> {n_failed_kept} kept ({n_failed_kept/max(1,len(failed_indices))*100:.1f}%)") + print(f" - Medium examples (0 {n_medium_kept} kept ({n_medium_kept/max(1,len(medium_indices))*100:.1f}%)") + print(f"Filtered dataset size: {len(filtered_df)}") + print(f"Total discarded data points: {len(original_df) - len(filtered_df)}") + print(f"Total percentage discarded: {100 * (len(original_df) - len(filtered_df)) / len(original_df):.2f}%") + else: + filtered_df = self.train_dataset.dataframe.copy() + + # Shuffle the dataset before sorting to randomize order of samples with same pass_rate + # This ensures better diversity in training batches + filtered_df = filtered_df.sample(frac=1.0, random_state=42 + epoch_idx).reset_index(drop=True) + + # Sort by per_prompt_length_budget for more efficient rollout batching + filtered_df = filtered_df.sort_values(by="per_prompt_length_budget", ascending=True).reset_index(drop=True) + + # Create filtered dataset copy + train_dataset_copy = deepcopy(self.train_dataset) + train_dataset_copy.dataframe = filtered_df + + # Create dataloader + self.train_dataloader = StatefulDataLoader( + dataset=train_dataset_copy, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=self.config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=collate_fn, + sampler=SequentialSampler(data_source=filtered_df), + ) + + print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + return train_dataset_copy + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + run_id=self.config.trainer.get("run_id", ""), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + for epoch in range(self.config.trainer.total_epochs): + train_dataset = self._create_priority_dataloader( + epoch_idx=epoch, + dynamic_filtering=self.config.data.get("dynamic_filtering", False), + enable_budget=self.config.trainer.get("enable_budget", False), + ) + # create create the default_local_dir if not exists + if not os.path.exists(self.config.trainer.default_local_dir): + os.makedirs(self.config.trainer.default_local_dir) + train_dataset.dataframe.to_csv(os.path.join(self.config.trainer.default_local_dir, + f"train_dataset_epoch_{epoch}.csv"), index=False) + + for batch_dict in self.train_dataloader: + metrics = {} + + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + with marked_timer("start_profile", timing_raw): + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile() + if self.use_critic: + self.critic_wg.start_profile() + if self.use_rm: + self.rm_wg.start_profile() + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + # pop those keys for generation + if "multi_modal_data" in new_batch.non_tensor_batch.keys(): + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "per_prompt_length_budget"], + ) + else: + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "per_prompt_length_budget"], + ) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, "red"): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with marked_timer("gen_max", timing_raw, "red"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(new_batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + new_batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + # Remove per_prompt_length_budget from gen_batch_output if it exists to prevent union conflict + if "per_prompt_length_budget" in gen_batch_output.non_tensor_batch: + gen_batch_output.non_tensor_batch.pop("per_prompt_length_budget") + + # raise ValueError(gen_batch_output, gen_batch_output.non_tensor_batch) + new_batch = new_batch.union(gen_batch_output) + + with marked_timer("reward", timing_raw, "yellow"): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + try: + reward_result = self.reward_fn(new_batch, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = self.reward_fn(new_batch) + reward_extra_infos_dict = {} + + new_batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, + # we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f"{num_prompt_in_batch=} < {prompt_bsz=}") + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f"{num_gen_batches=}. Keep generating...") + progress_bar.update(1) + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + + # === Updating === + + batch.batch["response_mask"] = compute_response_mask(batch) + + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, "blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, "olive"): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, "cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, "brown"): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + with marked_timer("pass_rate_append", timing_raw, "orange"): + # compute the pass rate for the batch + temp_df = pd.DataFrame({ + "prompt_id": batch.non_tensor_batch["prompt_id"], + "prev_pass_rate": batch.non_tensor_batch["score"] + }) + pass_rate_df = temp_df.groupby("prompt_id", as_index=False)["prev_pass_rate"].mean().set_index('prompt_id')[['prev_pass_rate']] + + # compute the average response length for each prompt_id in the batch + # Only include lengths for successful rollouts (score == 1) + response_length = batch.batch["responses"].shape[-1] + response_mask = batch.batch["attention_mask"][:, -response_length:] + response_lengths = response_mask.sum(dim=-1).float().cpu().numpy() # actual lengths per response + scores = batch.non_tensor_batch["score"] + + # Filter for successful rollouts only (score == 1) + successful_mask = scores == 1 + if np.any(successful_mask): + successful_prompt_ids = batch.non_tensor_batch["prompt_id"][successful_mask] + successful_response_lengths = response_lengths[successful_mask] + + temp_length_df = pd.DataFrame({ + "prompt_id": successful_prompt_ids, + "response_length": successful_response_lengths + }) + avg_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].mean().set_index('prompt_id')[['response_length']] + avg_length_df.rename(columns={"response_length": "prev_passed_avg_length"}, inplace=True) + max_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].max().set_index('prompt_id')[['response_length']] + max_length_df.rename(columns={"response_length": "prev_passed_max_length"}, inplace=True) + quantile_8_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].quantile(0.8).set_index('prompt_id')[['response_length']] + quantile_8_length_df.rename(columns={"response_length": "prev_passed_80th_length"}, inplace=True) + quantile_5_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].quantile(0.5).set_index('prompt_id')[['response_length']] + quantile_5_length_df.rename(columns={'response_length': "prev_passed_50th_length"}, inplace=True) + + # Update the dataframe with both pass rates and average lengths + self.train_dataset.dataframe = self.train_dataset.dataframe.set_index('prompt_id') + avg_length_df = avg_length_df.astype(self.train_dataset.dataframe['prev_passed_avg_length'].dtypes) + max_length_df = max_length_df.astype(self.train_dataset.dataframe['prev_passed_max_length'].dtypes) + quantile_8_length_df = quantile_8_length_df.astype(self.train_dataset.dataframe['prev_passed_avg_length'].dtypes) + quantile_5_length_df = quantile_5_length_df.astype(self.train_dataset.dataframe['prev_passed_avg_length'].dtypes) + + self.train_dataset.dataframe.update(pass_rate_df) + self.train_dataset.dataframe.update(avg_length_df) + self.train_dataset.dataframe.update(max_length_df) + self.train_dataset.dataframe.update(quantile_8_length_df) + self.train_dataset.dataframe.update(quantile_5_length_df) + print(quantile_8_length_df, quantile_5_length_df) + print(self.train_dataset.dataframe.columns) + + self.train_dataset.dataframe = self.train_dataset.dataframe.reset_index() + else: + # If no successful rollouts in this batch, only update pass rates + self.train_dataset.dataframe = self.train_dataset.dataframe.set_index('prompt_id') + self.train_dataset.dataframe.update(pass_rate_df) + self.train_dataset.dataframe = self.train_dataset.dataframe.reset_index() + print("No successful rollouts (score=1) in this batch, skipping length update") + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, "pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, "red"): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, "green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) # clear timing + + # Add per-prompt metrics for the current batch only + if hasattr(self.train_dataset, 'dataframe') and batch is not None: + # Get unique prompt_ids from the current batch + batch_prompt_ids = batch.non_tensor_batch["prompt_id"] + unique_prompt_ids = np.unique(batch_prompt_ids) + + # Filter dataframe to only include prompts from current batch + batch_df = self.train_dataset.dataframe[self.train_dataset.dataframe['prompt_id'].isin(unique_prompt_ids)] + + if len(batch_df) > 0: + metrics.update({ + "train/per_prompt_pass_rate_avg": batch_df["prev_pass_rate"].mean(), + "train/per_prompt_pass_rate_std": batch_df["prev_pass_rate"].std(), + "train/per_prompt_pass_rate_min": batch_df["prev_pass_rate"].min(), + "train/per_prompt_pass_rate_max": batch_df["prev_pass_rate"].max(), + "train/num_unique_prompts": len(unique_prompt_ids), + "train/per_prompt_length_budget_avg": batch_df["per_prompt_length_budget"].mean(), + "train/per_prompt_length_budget_std": batch_df["per_prompt_length_budget"].std(), + "train/per_prompt_length_budget_min": batch_df["per_prompt_length_budget"].min(), + "train/per_prompt_length_budget_max": batch_df["per_prompt_length_budget"].max(), + "train/prev_passed_max_length_avg": batch_df["prev_passed_max_length"].mean(), + "train/prev_passed_max_length_std": batch_df["prev_passed_max_length"].std(), + "train/prev_passed_max_length_min": batch_df["prev_passed_max_length"].min(), + "train/prev_passed_max_length_max": batch_df["prev_passed_max_length"].max(), + "train/prev_passed_avg_length_avg": batch_df["prev_passed_avg_length"].mean(), + "train/prev_passed_avg_length_std": batch_df["prev_passed_avg_length"].std(), + "train/prev_passed_avg_length_min": batch_df["prev_passed_avg_length"].min(), + "train/prev_passed_avg_length_max": batch_df["prev_passed_avg_length"].max(), + 'train/prev_passed_80th_length_avg': batch_df["prev_passed_80th_length"].mean(), + 'train/prev_passed_80th_length_std': batch_df["prev_passed_80th_length"].std(), + 'train/prev_passed_80th_length_min': batch_df["prev_passed_80th_length"].min(), + 'train/prev_passed_80th_length_max': batch_df["prev_passed_80th_length"].max(), + 'train/prev_passed_50th_length_avg': batch_df["prev_passed_50th_length"].mean(), + 'train/prev_passed_50th_length_std': batch_df["prev_passed_50th_length"].std(), + 'train/prev_passed_50th_length_min': batch_df["prev_passed_50th_length"].min(), + 'train/prev_passed_50th_length_max': batch_df["prev_passed_50th_length"].max() + }) + + metrics["train/num_gen_batches"] = num_gen_batches + metrics['train/num_prompts'] = len(train_dataset.dataframe) + metrics['train/perct_dropped_prompts'] = 100 * ( (len(self.train_dataset.dataframe) - len(train_dataset.dataframe)) / len(self.train_dataset.dataframe)) + metrics['train/n_drop_easy'] = self.n_drop_easy if self.n_drop_easy is not None else 0 + metrics['train/n_drop_hard'] = self.n_drop_hard if self.n_drop_hard is not None else 0 + metrics['train/epoch'] = epoch + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 + + def _save_dataset_state(self, local_global_step_folder): + """ + Save the current dataset state including updated pass rates and lengths. + This is crucial for resuming training with enable_budget feature. + """ + if not self.config.trainer.get('enable_budget', False): + return + + dataset_state_path = os.path.join(local_global_step_folder, 'dataset_state.pt') + + # Save the current dataset state + dataset_state = { + 'dataframe': self.train_dataset.dataframe.copy(), + 'n_drop_easy': getattr(self, 'n_drop_easy', 0), + 'n_drop_hard': getattr(self, 'n_drop_hard', 0), + } + + torch.save(dataset_state, dataset_state_path) + print(f"Saved dataset state to {dataset_state_path}") + print(f" - Dataset size: {len(self.train_dataset.dataframe)}") + print(f" - Pass rate range: {self.train_dataset.dataframe['prev_pass_rate'].min():.3f} - {self.train_dataset.dataframe['prev_pass_rate'].max():.3f}") + if 'prev_passed_max_length' in self.train_dataset.dataframe.columns: + print(f" - Max length range: {self.train_dataset.dataframe['prev_passed_max_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_max_length'].max():.1f}") + + def _load_dataset_state(self, global_step_folder): + """ + Load the dataset state including updated pass rates and lengths. + This restores the learned statistics from previous training. + """ + if not self.config.trainer.get('enable_budget', False): + return + + dataset_state_path = os.path.join(global_step_folder, 'dataset_state.pt') + + if os.path.exists(dataset_state_path): + print(f"Loading dataset state from {dataset_state_path}") + dataset_state = torch.load(dataset_state_path, weights_only=False) + + # Restore dataset with updated pass rates and lengths + self.train_dataset.dataframe = dataset_state['dataframe'] + self.n_drop_easy = dataset_state.get('n_drop_easy', 0) + self.n_drop_hard = dataset_state.get('n_drop_hard', 0) + + print(f"Restored dataset state:") + print(f" - Dataset size: {len(self.train_dataset.dataframe)}") + print(f" - Pass rate range: {self.train_dataset.dataframe['prev_pass_rate'].min():.3f} - {self.train_dataset.dataframe['prev_pass_rate'].max():.3f}") + if 'prev_passed_avg_length' in self.train_dataset.dataframe.columns: + print(f" - Avg length range: {self.train_dataset.dataframe['prev_passed_avg_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_avg_length'].max():.1f}") + if 'prev_passed_max_length' in self.train_dataset.dataframe.columns: + print(f" - Max length range: {self.train_dataset.dataframe['prev_passed_max_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_max_length'].max():.1f}") + else: + print(f"No dataset state found at {dataset_state_path}, starting with original dataset") + self.n_drop_easy = 0 + self.n_drop_hard = 0 + + def _save_checkpoint(self): + """ + Override to include dataset state saving for enable_budget feature. + """ + # Call parent method to save models and dataloader + super()._save_checkpoint() + + # Save additional dataset state for enable_budget + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + self._save_dataset_state(local_global_step_folder) + + def _load_checkpoint(self): + """ + Override to include dataset state loading for enable_budget feature. + """ + # Store original global_steps to detect if we loaded from checkpoint + original_global_steps = self.global_steps + + # Call parent method to load models and dataloader + result = super()._load_checkpoint() + + # If global_steps changed, we loaded from checkpoint + if self.global_steps > original_global_steps: + checkpoint_folder = self.config.trainer.default_local_dir + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + + # Find the same checkpoint folder that was loaded + from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path + global_step_folder = find_latest_ckpt_path(checkpoint_folder) + + if global_step_folder is not None: + self._load_dataset_state(global_step_folder) + + return result \ No newline at end of file diff --git a/recipe/dalu/m1_32b_k2think_continue_data1_math_code.sh b/recipe/dalu/m1_32b_k2think_continue_data1_math_code.sh new file mode 100644 index 000000000..66c5d0f8d --- /dev/null +++ b/recipe/dalu/m1_32b_k2think_continue_data1_math_code.sh @@ -0,0 +1,326 @@ +#!/bin/bash +#SBATCH --job-name=rl-32b-k2think-continue-data5_1_math_code +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="/lustrefs/users/haonan.li/Reasoning360/checkpoints/DALU/364222-rl-32b-k2think-continue-data5_1_math_code-K2-Think" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_ID="w4ohol8m" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}' ]" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${reasoning_gym_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=LLM360/K2-Think +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 48)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.2 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.run_id=${WANDB_ID} \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_32b_k2think_continue_data2_all.sh b/recipe/dalu/m1_32b_k2think_continue_data2_all.sh new file mode 100644 index 000000000..6f2d177d6 --- /dev/null +++ b/recipe/dalu/m1_32b_k2think_continue_data2_all.sh @@ -0,0 +1,325 @@ +#!/bin/bash +#SBATCH --job-name=rl-32b-k2think-continue-data5_2_all +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${synlogic_train_path}', '${codeio_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${synlogic_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=LLM360/K2-Think +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 48)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_32b_k2think_continue_data2_math_code.sh b/recipe/dalu/m1_32b_k2think_continue_data2_math_code.sh new file mode 100644 index 000000000..0ac48b224 --- /dev/null +++ b/recipe/dalu/m1_32b_k2think_continue_data2_math_code.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-32b-k2think-continue-data5_2 +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="/lustrefs/users/haonan.li/Reasoning360/checkpoints/DALU/363947-rl-32b-k2think-continue-data5_2-K2-Think" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_ID="lyl3mkni" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${codeio_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=LLM360/K2-Think +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 48)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + +trainer.run_id=${WANDB_ID} \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_dalu_70b_data5_2_am.sh b/recipe/dalu/m1_dalu_70b_data5_2_am.sh new file mode 100644 index 000000000..13b0b49b0 --- /dev/null +++ b/recipe/dalu/m1_dalu_70b_data5_2_am.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-70b-am-data_5_2_mix6 +#SBATCH --nodes=16 +#SBATCH --ntasks=16 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_RUN_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}', '${ifbench_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_sft_reasoning_am_cos_epoch/checkpoints/checkpoint_0002250 +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + +trainer.run_id=${WANDB_RUN_ID} \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_dalu_70b_data5_2_inst.sh b/recipe/dalu/m1_dalu_70b_data5_2_inst.sh new file mode 100644 index 000000000..5090c24e7 --- /dev/null +++ b/recipe/dalu/m1_dalu_70b_data5_2_inst.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-70b-inst-data_5_2_mix6 +#SBATCH --nodes=16 +#SBATCH --ntasks=16 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_RUN_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}', '${ifbench_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_sft_instruct_cos_epoch/checkpoints/checkpoint_0000900 +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + +trainer.run_id=${WANDB_RUN_ID} \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_dalu_70b_data5_2_ot.sh b/recipe/dalu/m1_dalu_70b_data5_2_ot.sh new file mode 100644 index 000000000..50e58d361 --- /dev/null +++ b/recipe/dalu/m1_dalu_70b_data5_2_ot.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-70b-ot-data_5_2_mix6 +#SBATCH --nodes=16 +#SBATCH --ntasks=16 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +WANDB_RUN_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}', '${ifbench_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${humaneval_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_sft_reasoning_ot_cos_epoch/checkpoints/checkpoint_0006300 +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + +trainer.run_id=${WANDB_RUN_ID} \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_thinktype_baseline.sh b/recipe/dalu/m1_thinktype_baseline.sh new file mode 100644 index 000000000..3f78f742f --- /dev/null +++ b/recipe/dalu/m1_thinktype_baseline.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-thinktype-baseline +#SBATCH --nodes=8 +#SBATCH --ntasks=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" +WANDB_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${reasoning_gym_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${reasoning_gym_test_path}','${gpqa_diamond_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=MBZUAI-IFM/TP-base-7B +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=ThinkType +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=2 +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=10 \ + +trainer.run_id=${WANDB_ID} \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=False \ + +data.dynamic_filtering=False \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m1_thinktype_etf.sh b/recipe/dalu/m1_thinktype_etf.sh new file mode 100644 index 000000000..4ad5ad890 --- /dev/null +++ b/recipe/dalu/m1_thinktype_etf.sh @@ -0,0 +1,327 @@ +#!/bin/bash +#SBATCH --job-name=rl-thinktype-etf +#SBATCH --nodes=8 +#SBATCH --ntasks=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" +WANDB_ID="" +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-320:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 \ + NCCL_TIMEOUT=7200 \ + NCCL_BLOCKING_WAIT=1 \ + TORCH_NCCL_TRACE_BUFFER_SIZE=1000 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_2_datamix_6 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${reasoning_gym_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}', '${amc_test_path}', '${aime_test_path}', '${math_test_path}', '${reasoning_gym_test_path}','${gpqa_diamond_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=MBZUAI-IFM/TP-ETF-7B +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=ThinkType +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=2 +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=10 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.run_id=${WANDB_ID} \ + +trainer.enable_budget=False \ + +data.dynamic_filtering=False \ + +data.pass_rate_upper_bound=0.9 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/m2_dalu_7b.sh b/recipe/dalu/m2_dalu_7b.sh new file mode 100644 index 000000000..23c310f19 --- /dev/null +++ b/recipe/dalu/m2_dalu_7b.sh @@ -0,0 +1,320 @@ +#!/bin/bash +#SBATCH --job-name=dalu-7b-data_5_3 +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-099:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export OMPI_MCA_coll_hcoll_enable=0 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + NCCL_SOCKET_IFNAME=eth0 \ + UCX_TLS=rc \ + UCX_NET_DEVICES=mlx5_ib0:1 \ + NCCL_DEBUG=WARN \ + NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ + NCCL_IB_PCI_RELAXED_ORDERING=1 \ + NCCL_IB_QPS_PER_CONNECTION=4 \ + NCCL_IGNORE_CPU_AFFINITY=1 \ + NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ + NCCL_PXN_DISABLE=1 \ + NCCL_MIN_NCHANNELS=32 \ + SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ + SHARP_COLL_ENABLE_SAT=1 \ + SHARP_COLL_LOG_LEVEL=3 \ + SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ + NCCL_COLLNET_ENABLE=1 + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/lustrefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len_rm_flipscore_dpsk_jshape_no0_5_3 +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_430.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_500.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${synlogic_train_path}', '${codeio_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${synlogic_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B +CONDA_BIN_PATH=/lustrefs/users/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=4 +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=1 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate diff --git a/recipe/dalu/main_dalu.py b/recipe/dalu/main_dalu.py new file mode 100644 index 000000000..e75796a3a --- /dev/null +++ b/recipe/dalu/main_dalu.py @@ -0,0 +1,172 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available + +from .dalu_ray_trainer import RayDALUTrainer +import sys +sys.set_int_max_str_digits(10000) + +@hydra.main(config_path="config", config_name="dalu_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + }, + num_cpus=config.ray_init.num_cpus, + ) + + if ( + is_cuda_available + and OmegaConf.select(config.trainer, "profile_steps") is not None + and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 + ): + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + tokenizer = hf_tokenizer(local_path) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_fn = load_reward_manager( + config, + tokenizer, + 0, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + + # Note that we always use function-based RM for validation + val_reward_fn = load_reward_manager( + config, + tokenizer, + 1, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayDALUTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + device_name=config.trainer.device, + ) + trainer.init_workers() + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/recipe/dalu/run_dalu_1.5b.sh b/recipe/dalu/run_dalu_1.5b.sh new file mode 100644 index 000000000..8171165c0 --- /dev/null +++ b/recipe/dalu/run_dalu_1.5b.sh @@ -0,0 +1,305 @@ +#!/bin/bash +#SBATCH --job-name=dalu-1.5b +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --account=iq + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.157:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len +TEST_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_1000.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_541.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${synlogic_train_path}', '${codeio_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}', '${ifbench_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${aime25_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${synlogic_test_path}','${codeio_test_path}','${multihier_test_path}','${nemotron_test_path}','${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +CONDA_BIN_PATH=/mnt/weka/home/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 16)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=1 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=1 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate \ No newline at end of file diff --git a/recipe/dalu/run_dalu_32b.sh b/recipe/dalu/run_dalu_32b.sh new file mode 100644 index 000000000..103cdbf5f --- /dev/null +++ b/recipe/dalu/run_dalu_32b.sh @@ -0,0 +1,323 @@ +#!/bin/bash +#SBATCH --job-name=dalu-32b-data_5_3 +#SBATCH --nodes=16 +#SBATCH --ntasks=16 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --account=iq + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.157:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +# force IB and pick the rails explicitly +export NCCL_NET=IB +export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" + +# stripe across rails +export NCCL_CROSS_NIC=1 + +# NVLink/NVSwitch offload on HGX/DGX nodes +export NCCL_NVLS_ENABLE=1 + +# optional stability knob (turn on if you’ve seen IBEXT problems) +export NCCL_IBEXT_DISABLE=1 + +# optional: select an IB Service Level if your fabric/QoS defines one (0–15) +# leave unset if your fabric doesn’t use SL policies +# export NCCL_IB_SL=0 + +# misc quality-of-life/perf +export NCCL_DEBUG=warn +export NCCL_SOCKET_IFNAME="^lo,docker,virbr" +export CUDA_DEVICE_MAX_CONNECTIONS=8 + + + +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/backup/train_scored_dedup_am_12k_len_rm_flipscore_dpsk_jshape_no0_5_3 +TEST_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet # There might be bug, wait for fix +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${synlogic_train_path}', '${codeio_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}']" # Use math as example, add to more tasks as needed +# test_files="['${math_train1_path}']" +test_files="['${aime25_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${synlogic_test_path}','${multihier_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed +# test_files="['${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=/mnt/sharefs/users/haonan.li/models/32b-am-1084 +CONDA_BIN_PATH=/mnt/weka/home/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.4 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=8 +gen_max_num_seqs=512 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=20 \ + trainer.test_freq=20 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=1 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate \ No newline at end of file diff --git a/recipe/dalu/run_dalu_7b.sh b/recipe/dalu/run_dalu_7b.sh new file mode 100644 index 000000000..0f84751e1 --- /dev/null +++ b/recipe/dalu/run_dalu_7b.sh @@ -0,0 +1,305 @@ +#!/bin/bash +#SBATCH --job-name=dalu-7b +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --account=iq + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.157:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 +export TRITON_HOME=/tmp/triton_cache + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +#TRAIN_DATA_DIR=/mnt/sharefs/users/zhuojun.cheng/guru_data/train/postprocessed_dedup_am +TRAIN_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/train_scored_dedup_am_12k_len +TEST_DATA_DIR=/mnt/sharefs/users/haonan.li/data/k2/test_12k_len +# Math (train) +math_train1_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +math_train2_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime25_test_path=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoning_gym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoning_gym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +reasoning_gym_large_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_4.3k.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +cruxeval_i_test_path=${TEST_DATA_DIR}/simulation__cruxeval-i_800.parquet +cruxeval_o_test_path=${TEST_DATA_DIR}/simulation__cruxeval-o_800.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet +finqa_test_path=${TEST_DATA_DIR}/table__finqa_1.1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_1000.parquet +nemotron_large_test_path=${TEST_DATA_DIR}/stem__nemotron_10.0k.parquet + +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# IfBench (train) +ifbench_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet +# IfBench (test) +ifbench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet +ifbench_large_test_path=${TEST_DATA_DIR}/ifbench_8k.parquet + +# OOD (test) +ifeval_test_path=${TEST_DATA_DIR}/ood__ifeval_541.parquet +livebench_data_analysis_test_path=${TEST_DATA_DIR}/ood__livebench_data_analysis_150.parquet +livebench_language_test_path=${TEST_DATA_DIR}/ood__livebench_language_140.parquet +livebench_reasoning_test_path=${TEST_DATA_DIR}/ood__livebench_reasoning_150.parquet + +train_files="['${math_train1_path}', '${math_train2_path}', '${leetcode_train_path}', '${livecodebench_train_path}', '${primeintellect_train_path}', '${taco_train_path}', '${arcagi1_train_path}', '${arcagi2_train_path}', '${barc_train_path}', '${graph_train_path}', '${ordering_train_path}', '${zebra_train_path}', '${reasoning_gym_train_path}', '${synlogic_train_path}', '${codeio_train_path}', '${hitab_train_path}', '${multihier_train_path}', '${webinstruct_train_path}', '${nemotron_train_path}', '${ifbench_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${aime25_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoning_gym_test_path}','${synlogic_test_path}','${codeio_test_path}','${multihier_test_path}','${nemotron_test_path}','${supergpqa_test_path}','${ifeval_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B +CONDA_BIN_PATH=/mnt/weka/home/haonan.li/miniconda3/envs/Reasoning360/bin/ + +# =================== Logging =================== +WANDB_PROJECT=DALU +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# Set default local directory for checkpoints +DEFAULT_LOCAL_DIR="checkpoints/${WANDB_PROJECT}/${WANDB_EXPERIMENT_NAME}" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME=$(basename "$RESUME_CKPT_DIR_NAME") + DEFAULT_LOCAL_DIR="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 16)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=1 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dalu.main_dalu \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=1 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ + trainer.default_local_dir="${DEFAULT_LOCAL_DIR}" \ + +trainer.enable_budget=True \ + +data.dynamic_filtering=True \ + +data.pass_rate_upper_bound=1 \ + +data.initial_pass_rate_column=deepseek_r1_0528_pass_rate \ No newline at end of file diff --git a/recipe/dalu/runtime_env.yaml b/recipe/dalu/runtime_env.yaml new file mode 100644 index 000000000..13f4b2ba2 --- /dev/null +++ b/recipe/dalu/runtime_env.yaml @@ -0,0 +1,5 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + VLLM_USE_V1: "1" diff --git a/recipe/dapo/README.md b/recipe/dapo/README.md index d2395622a..75b80f1aa 100644 --- a/recipe/dapo/README.md +++ b/recipe/dapo/README.md @@ -1,4 +1,4 @@ -# DAPO Open-Source Implementation +# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) > Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) @@ -6,10 +6,10 @@ > > **🔥 News!!!** > -> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `gm-tyx/puffin/main`](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). +> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `recipe/dapo`](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). > - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). -🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) +🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) > We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. > @@ -17,13 +17,6 @@ ## Quickstart -0. (For reproduction) Checkout to the commit: - -```bash -git fetch origin gm-tyx/puffin/main -git checkout f7e13f5 -``` - 1. Prepare the datasets **on the Ray cluster**: ```bash @@ -37,22 +30,25 @@ cd verl # Repo root export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster # Set the runtime environment like env vars and pip packages for the Ray cluster in yaml -export RUNTIME_ENV="./verl/trainer/runtime_env.yaml" -bash recipe/dapo/run_dapo_qwen2.5_32b.sh +export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster +bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts ``` ## Reproduction Runs -| Setup | AIME 2024 Acc. | Training Script | Training Record | -| -------------------------------------------- | -------------- | ---------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | -| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Dynamic Sampling | 50% | [run_dapo_wo_ds_qwen2.5_32b.sh](./run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO | 52% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| Setup | AIME 2024 Acc. | Hardware | Image | Commit | Environment Variables | Training Script | Training Record | +| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | +| DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -## Configuration +> [!IMPORTANT] +> +> **📢 Call for Contribution!** +> +> Welcome to submit your reproduction runs and setups! -> [!NOTE] -> Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. +## Configuration ### Separated Clip Epsilons (-> Clip-Higher) @@ -172,3 +168,25 @@ if self.overlong_buffer_cfg.enable: overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) reward += overlong_reward ``` + +## FAQ + +### Where is the "Overlong Filtering" in the paper? + +Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. + +### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)? + +[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features. + +[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features. + +### Why can't I produce similar results after modifications? + +RL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve. + +We strongly recommend to only modify one thing at a time. + +We also list some known problems here: + +1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation. diff --git a/recipe/dapo/config/dapo_fsdp_config.yaml b/recipe/dapo/config/dapo_fsdp_config.yaml new file mode 100644 index 000000000..47141447e --- /dev/null +++ b/recipe/dapo/config/dapo_fsdp_config.yaml @@ -0,0 +1,26 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit + diff --git a/recipe/dapo/config/dapo_megatron_config.yaml b/recipe/dapo/config/dapo_megatron_config.yaml new file mode 100644 index 000000000..5b83fab85 --- /dev/null +++ b/recipe/dapo/config/dapo_megatron_config.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit \ No newline at end of file diff --git a/recipe/dapo/config/dapo_trainer.yaml b/recipe/dapo/config/dapo_trainer.yaml index 0c518b7a9..47ac00fd6 100644 --- a/recipe/dapo/config/dapo_trainer.yaml +++ b/recipe/dapo/config/dapo_trainer.yaml @@ -19,6 +19,7 @@ reward_model: algorithm: filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig enable: False # We try to avoid forgetting to set enable metric: null # acc / score / seq_reward / seq_final_reward / ... max_num_gen_batches: 0 # Non-positive values mean no upper limit diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index 4ba1dd441..faf23423c 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -33,7 +33,14 @@ compute_timing_metrics, reduce_metrics, ) -from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.utils.profiler import marked_timer class RayDAPOTrainer(RayPPOTrainer): @@ -89,6 +96,21 @@ def fit(self): for batch_dict in self.train_dataloader: metrics = {} + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + with marked_timer("start_profile", timing_raw): + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile() + if self.use_critic: + self.critic_wg.start_profile() + if self.use_rm: + self.rm_wg.start_profile() + new_batch: DataProto = DataProto.from_single_dict(batch_dict) num_gen_batches += 1 # pop those keys for generation @@ -102,16 +124,19 @@ def fit(self): batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"], ) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) is_last_step = self.global_steps >= self.total_training_steps - with _timer("step", timing_raw): + with marked_timer("step", timing_raw): # generate a batch - with _timer("gen", timing_raw): + with marked_timer("gen", timing_raw, "red"): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): + with marked_timer("gen_max", timing_raw, "red"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) @@ -126,12 +151,14 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object) + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) new_batch = new_batch.union(gen_batch_output) - with _timer("reward", timing_raw): + with marked_timer("reward", timing_raw, "yellow"): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. @@ -145,7 +172,7 @@ def fit(self): try: reward_result = self.reward_fn(new_batch, return_dict=True) reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result["reward_extra_info"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) except Exception as e: print(f"Error in reward_fn: {e}") reward_tensor = self.reward_fn(new_batch) @@ -153,14 +180,19 @@ def fit(self): new_batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: - new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) # TODO: This will be cleared if we use multiple genenration batches + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches else: new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] @@ -171,20 +203,30 @@ def fit(self): metric_name = self.config.algorithm.filter_groups.metric if metric_name == "seq_final_reward": # Turn to numpy for easier filtering - new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) elif metric_name == "seq_reward": - new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) # Collect the sequence reward for each trajectory prompt_uid2metric_vals = defaultdict(list) - for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]): + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): prompt_uid2metric_vals[uid].append(metric_val) prompt_uid2metric_std = {} for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) - kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1] + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] num_prompt_in_batch += len(kept_prompt_uids) kept_traj_idxs = [] @@ -201,9 +243,14 @@ def fit(self): max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: print(f"{num_gen_batches=}. Keep generating...") + progress_bar.update(1) continue else: - raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.") + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) else: # Align the batch traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n @@ -225,30 +272,30 @@ def fit(self): batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # recompute old_log_probs - with _timer("old_log_prob", timing_raw): + with marked_timer("old_log_prob", timing_raw, "blue"): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer("ref", timing_raw): + with marked_timer("ref", timing_raw, "olive"): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer("values", timing_raw): + with marked_timer("values", timing_raw, "cyan"): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): + with marked_timer("adv", timing_raw, "brown"): # compute advantages, executed on the driver process norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) batch = compute_advantage( @@ -262,7 +309,7 @@ def fit(self): # update critic if self.use_critic: - with _timer("update_critic", timing_raw): + with marked_timer("update_critic", timing_raw, "pink"): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) @@ -270,23 +317,39 @@ def fit(self): # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer("update_actor", timing_raw): + with marked_timer("update_actor", timing_raw, "red"): actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer("testing", timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, "green"): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): - with _timer("save_checkpoint", timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, "green"): self._save_checkpoint() + with marked_timer("stop_profile", timing_raw): + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index e2587e91f..268591b8d 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -15,11 +15,15 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import os +import socket import hydra import ray +from omegaconf import OmegaConf -from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available from .dapo_ray_trainer import RayDAPOTrainer @@ -33,11 +37,21 @@ def run_ppo(config) -> None: if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}}, + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + }, num_cpus=config.ray_init.num_cpus, ) - runner = TaskRunner.remote() + if ( + is_cuda_available + and OmegaConf.select(config.trainer, "profile_steps") is not None + and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 + ): + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() ray.get(runner.run.remote(config)) @@ -51,6 +65,8 @@ def run(self, config): from verl.utils.fs import copy_to_local + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) @@ -64,8 +80,8 @@ def run(self, config): processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker @@ -104,7 +120,7 @@ def run(self, config): # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == "fsdp": + if config.reward_model.strategy in {"fsdp", "fsdp2"}: from verl.workers.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker @@ -118,38 +134,19 @@ def run(self, config): role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager - - reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError - - compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls( - tokenizer=tokenizer, - num_examine=0, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, + reward_fn = load_reward_manager( + config, + tokenizer, + 0, max_resp_len=config.data.max_response_length, overlong_buffer_cfg=config.reward_model.overlong_buffer, ) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls( - tokenizer=tokenizer, - num_examine=1, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, + val_reward_fn = load_reward_manager( + config, + tokenizer, + 1, max_resp_len=config.data.max_response_length, overlong_buffer_cfg=config.reward_model.overlong_buffer, ) @@ -164,6 +161,7 @@ def run(self, config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit() diff --git a/recipe/dapo/run_dapo_early_qwen2.5_32b.sh b/recipe/dapo/run_dapo_early_qwen2.5_32b.sh index c7bd5c189..81bc2cb12 100644 --- a/recipe/dapo/run_dapo_early_qwen2.5_32b.sh +++ b/recipe/dapo/run_dapo_early_qwen2.5_32b.sh @@ -115,7 +115,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/dapo/run_dapo_qwen2.5_32b.sh b/recipe/dapo/run_dapo_qwen2.5_32b.sh index 6eec26c80..feb783a7c 100644 --- a/recipe/dapo/run_dapo_qwen2.5_32b.sh +++ b/recipe/dapo/run_dapo_qwen2.5_32b.sh @@ -117,7 +117,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh b/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh index 6064b5be6..b0491aedf 100644 --- a/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh +++ b/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh @@ -112,7 +112,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/dapo/runtime_env.yaml b/recipe/dapo/runtime_env.yaml new file mode 100644 index 000000000..13f4b2ba2 --- /dev/null +++ b/recipe/dapo/runtime_env.yaml @@ -0,0 +1,5 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + VLLM_USE_V1: "1" diff --git a/recipe/dapo/test_dapo_7b.sh b/recipe/dapo/test_dapo_7b.sh index fe5cb297b..2bb94963d 100644 --- a/recipe/dapo/test_dapo_7b.sh +++ b/recipe/dapo/test_dapo_7b.sh @@ -117,7 +117,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh index 39918ac2d..9574f7722 100644 --- a/recipe/dapo/test_dapo_7b_math.sh +++ b/recipe/dapo/test_dapo_7b_math.sh @@ -34,6 +34,7 @@ NNODES=${NNODES:-8} NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} @@ -54,7 +55,7 @@ offload=True gen_tp=4 fsdp_size=32 -# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model +# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361 python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ @@ -114,7 +115,7 @@ python3 -m verl.trainer.main_ppo \ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ diff --git a/verl/recipe/dapo/run_dapo_early_qwen2.5_32b.sh b/recipe/dapo/test_dapo_7b_math_lora.sh similarity index 70% rename from verl/recipe/dapo/run_dapo_early_qwen2.5_32b.sh rename to recipe/dapo/test_dapo_7b_math_lora.sh index 3b1ebfce6..d68e5d625 100644 --- a/verl/recipe/dapo/run_dapo_early_qwen2.5_32b.sh +++ b/recipe/dapo/test_dapo_7b_math_lora.sh @@ -1,8 +1,8 @@ #!/usr/bin/env bash -set -euxo pipefail +set -xeuo pipefail project_name='DAPO' -exp_name='DAPO-Early-Qwen2.5-32B' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' adv_estimator=grpo @@ -15,28 +15,26 @@ clip_ratio_low=0.2 clip_ratio_high=0.28 max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 20)) +max_response_length=$((1024 * 8)) enable_overlong_buffer=True overlong_buffer_len=$((1024 * 4)) overlong_penalty_factor=1.0 -# An early version for DAPO -loss_agg_mode="seq-mean-token-sum" +loss_agg_mode="token-mean" -enable_filter_groups=False -gen_prompt_bsz=512 # NOTE: no filtering here train_prompt_bsz=512 -train_prompt_mini_bsz=32 n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-16} +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} @@ -47,25 +45,24 @@ top_p=1.0 top_k=-1 # 0 for HF rollout, -1 for vLLM rollout val_top_p=0.7 - # Performance Related Parameter -sp_size=8 +sp_size=4 use_dynamic_bsz=True -actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) -infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) offload=True gen_tp=4 +fsdp_size=32 + +# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.src.main_dapo \ +python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ data.prompt_key=prompt \ data.truncation='left' \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ data.train_batch_size=${train_prompt_bsz} \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ algorithm.adv_estimator=${adv_estimator} \ @@ -76,8 +73,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ - algorithm.filter_groups.enable=${enable_filter_groups} \ actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ @@ -85,10 +82,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ actor_rollout_ref.model.path="${MODEL_PATH}" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=8 \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ actor_rollout_ref.actor.optim.weight_decay=0.1 \ @@ -105,7 +100,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ actor_rollout_ref.rollout.temperature=${temperature} \ actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.top_k=${top_k} \ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ @@ -113,19 +108,23 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.rollout.val_kwargs.n=1 \ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ trainer.nnodes="${NNODES}" \ trainer.val_before_train=True \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=1 \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ No newline at end of file + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_7b_math_megatron.sh b/recipe/dapo/test_dapo_7b_math_megatron.sh index 0e3558a4d..4c16cd7d4 100644 --- a/recipe/dapo/test_dapo_7b_math_megatron.sh +++ b/recipe/dapo/test_dapo_7b_math_megatron.sh @@ -118,7 +118,7 @@ python3 -m verl.trainer.main_ppo \ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=16 \ diff --git a/recipe/dapo/test_dapo_dspk_671b_megatron.sh b/recipe/dapo/test_dapo_dspk_671b_megatron.sh new file mode 100644 index 000000000..c6988d114 --- /dev/null +++ b/recipe/dapo/test_dapo_dspk_671b_megatron.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# 0. download the config +# only need to download the configuration_deepseek.py and config.json +# remove the `quantization_config` in the `config.json` +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported +huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json + +project_name='DAPO' +exp_name='DAPO-DeepSeek-671b-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 4)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=0.1 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 # must be > n_gpus. need to fix +n_resp_per_prompt=2 +train_prompt_mini_bsz=16 # mini_bsz * n >= micro_bsz * pp * dp + +NNODES=${NNODES:-64} + +# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main +# change the MODEL_PATH and MCORE_MODEL_PATH to your own path +# Paths +MODEL_PATH="" +MCORE_MODEL_PATH="" +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +aime24_test_path=${RAY_DATA_HOME}/data/aime-2024.parquet +# TEST_FILE="['$math500_test_path', '$aime24_test_path']" + +TEST_FILE="['$aime24_test_path']" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=32 +train_tp=1 +train_ep=32 +train_pp=16 + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_qwen3_30b_math.sh b/recipe/dapo/test_dapo_qwen3_30b_math.sh index 56ebd0397..741e0d6d0 100644 --- a/recipe/dapo/test_dapo_qwen3_30b_math.sh +++ b/recipe/dapo/test_dapo_qwen3_30b_math.sh @@ -111,7 +111,7 @@ python3 -m verl.trainer.main_ppo \ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ diff --git a/recipe/entropy/32b_clip_cov.sh b/recipe/entropy/32b_clip_cov.sh new file mode 100644 index 000000000..65cbe2e14 --- /dev/null +++ b/recipe/entropy/32b_clip_cov.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='clipcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=1 +clip_ratio_high=1 +clip_cov_ratio=0.0002 +clip_cov_lb=1.0 +clip_cov_ub=5.0 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="clip_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=20480 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.02 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/verl/recipe/dapo/test_dapo_7b.sh b/recipe/entropy/32b_kl_cov.sh similarity index 69% rename from verl/recipe/dapo/test_dapo_7b.sh rename to recipe/entropy/32b_kl_cov.sh index f370fffb4..b0ba4519f 100644 --- a/verl/recipe/dapo/test_dapo_7b.sh +++ b/recipe/entropy/32b_kl_cov.sh @@ -1,8 +1,11 @@ #!/usr/bin/env bash -set -euxo pipefail +set -xeuo pipefail -project_name='DAPO' -exp_name='DAPO-Qwen2.5-7B-Math-Test' +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='klcov' adv_estimator=grpo @@ -12,23 +15,24 @@ use_kl_loss=False kl_loss_coef=0.0 clip_ratio_low=0.2 -clip_ratio_high=0.28 +clip_ratio_high=0.2 max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 2)) -enable_overlong_buffer=True -overlong_buffer_len=512 +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) overlong_penalty_factor=1.0 loss_agg_mode="token-mean" - +loss_mode="kl_cov" enable_filter_groups=True filter_groups_metric=acc max_num_gen_batches=10 -train_prompt_bsz=512 +train_prompt_bsz=256 gen_prompt_bsz=$((train_prompt_bsz * 3)) train_prompt_mini_bsz=32 -n_resp_per_prompt=16 +n_resp_per_prompt=8 +max_token=20480 # Ray RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} @@ -37,15 +41,17 @@ RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} NNODES=${NNODES:-4} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} # Algorithm temperature=1.0 top_p=1.0 top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.0002 # Mathematically equivalent use_dynamic_bsz=True @@ -53,24 +59,29 @@ infer_micro_batch_size=null train_micro_batch_size=null offload=False -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.src.main_dapo \ +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ data.prompt_key=prompt \ data.truncation='left' \ + data.filter_overlong_prompts=False \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ data.gen_batch_size=${gen_prompt_bsz} \ data.train_batch_size=${train_prompt_bsz} \ - data.truncation='left' \ + data.return_raw_chat=True \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ algorithm.kl_ctrl.kl_coef=${kl_coef} \ @@ -81,17 +92,14 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ actor_rollout_ref.model.path="${MODEL_PATH}" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ @@ -99,20 +107,19 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.loss_agg_mode=True \ actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ actor_rollout_ref.rollout.temperature=${temperature} \ actor_rollout_ref.rollout.top_p=${top_p} \ actor_rollout_ref.rollout.top_k="${top_k}" \ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ actor_rollout_ref.rollout.val_kwargs.n=1 \ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ @@ -122,14 +129,14 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=2 \ - trainer.save_freq=2 \ - trainer.total_epochs=1 \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable \ No newline at end of file + trainer.resume_mode=disable diff --git a/recipe/entropy/32b_kl_cov_mininbsz.sh b/recipe/entropy/32b_kl_cov_mininbsz.sh new file mode 100644 index 000000000..15d191838 --- /dev/null +++ b/recipe/entropy/32b_kl_cov_mininbsz.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='klcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="kl_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=16 +n_resp_per_prompt=8 +max_token=20480 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.0002 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/recipe/entropy/7b_clip_cov.sh b/recipe/entropy/7b_clip_cov.sh new file mode 100644 index 000000000..7a68f37df --- /dev/null +++ b/recipe/entropy/7b_clip_cov.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-7B' +exp_name='clipcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=1 +clip_ratio_high=1 +clip_cov_ratio=0.0002 +clip_cov_lb=1.0 +clip_cov_ub=5.0 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="clip_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=30720 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.2 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/verl/recipe/dapo/run_dapo_qwen2.5_32b.sh b/recipe/entropy/7b_kl_cov.sh similarity index 60% rename from verl/recipe/dapo/run_dapo_qwen2.5_32b.sh rename to recipe/entropy/7b_kl_cov.sh index 97007a6ff..5dd1f8870 100644 --- a/verl/recipe/dapo/run_dapo_qwen2.5_32b.sh +++ b/recipe/entropy/7b_kl_cov.sh @@ -1,8 +1,11 @@ -/usr/bin/env bash -set -euxo pipefail +#!/usr/bin/env bash +set -xeuo pipefail -project_name='DAPO' -exp_name='DAPO-Qwen2.5-32B' +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-7B' +exp_name='klcov' adv_estimator=grpo @@ -12,121 +15,127 @@ use_kl_loss=False kl_loss_coef=0.0 clip_ratio_low=0.2 -clip_ratio_high=0.28 +clip_ratio_high=0.2 max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 20)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) overlong_penalty_factor=1.0 loss_agg_mode="token-mean" - +loss_mode="kl_cov" enable_filter_groups=True filter_groups_metric=acc max_num_gen_batches=10 -train_prompt_bsz=512 +train_prompt_bsz=256 gen_prompt_bsz=$((train_prompt_bsz * 3)) -n_resp_per_prompt=16 train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=30720 # Ray RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} WORKING_DIR=${WORKING_DIR:-"${PWD}"} RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-16} +NNODES=${NNODES:-4} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} # Algorithm temperature=1.0 top_p=1.0 top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.002 -# Performance Related Parameter -sp_size=8 +# Mathematically equivalent use_dynamic_bsz=True -actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) -infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) -offload=True -gen_tp=4 +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.src.main_dapo \ +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ data.prompt_key=prompt \ data.truncation='left' \ + data.filter_overlong_prompts=False \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ data.gen_batch_size=${gen_prompt_bsz} \ data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ actor_rollout_ref.model.path="${MODEL_PATH}" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ actor_rollout_ref.rollout.temperature=${temperature} \ actor_rollout_ref.rollout.top_p=${top_p} \ actor_rollout_ref.rollout.top_k="${top_k}" \ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ reward_model.reward_manager=dapo \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=1 \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ No newline at end of file + trainer.resume_mode=disable diff --git a/recipe/entropy/README.md b/recipe/entropy/README.md new file mode 100644 index 000000000..5238cec84 --- /dev/null +++ b/recipe/entropy/README.md @@ -0,0 +1,110 @@ +
+ +# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning. + +[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue +)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861) + + + + +
+ + +# 🎉News + +- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29). +- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. + + + +# ✨Getting started + +After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run: + +``` +cd verl +conda activate your_env +bash recipe/dapo/7b_kl_cov.sh +``` + +While for training Qwen2.5-32B on multi nodes, you can run the following commands: + +``` +cd verl +conda activate your_env +bash recipe/dapo/32b_kl_cov.sh +``` + +# 📖Introduction + +
+ issue +
+ +This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. + +
+ issue +
+ +Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. + +# 📃Evaluation + +
+ issue +
+ + +Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. +| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** | +| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: | +| *Qwen2.5-7B* | | | | | | | | | +| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 | +| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 | +| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 | +| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** | +| *Qwen2.5-32B* | | | | | | | | | +| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 | +| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 | +| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 | +| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** | + +Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively. + + +# 🎈Citation +If you find this paper or repo helpful, please cite us. + +```bibtex +@article{cui2025entropy, + title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models}, + author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others}, + journal={arXiv preprint arXiv:2505.22617}, + year={2025} +} +``` +# 🌻Acknowledgement +We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! + +# 📬 Contact + +For questions, discussion, or collaboration opportunities, feel free to contact: +- Ganqu Cui: cuiganqu@pjlab.org.cn +- Yuchen Zhang: yuchen.zhang2003@gmail.com +- Jiacheng Chen: jackchan9345@gmail.com +- Ning Ding: ningding.cs@gmail.com + diff --git a/recipe/entropy/config/entropy_trainer.yaml b/recipe/entropy/config/entropy_trainer.yaml new file mode 100644 index 000000000..969c72946 --- /dev/null +++ b/recipe/entropy/config/entropy_trainer.yaml @@ -0,0 +1,39 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_kwargs: + overlong_buffer_cfg: ${reward_model.overlong_buffer} + reward_manager: dapo + overlong_buffer: + enable: False + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit + +trainer: + project_name: verl-entropy + +actor_rollout_ref: + actor: + policy_loss: + loss_mode: "vanilla" # /clip-cov / kl-cov from https://arxiv.org/abs/2505. + clip_cov_ratio: 0.0002 # for clip-cov loss + clip_cov_lb: 1.0 # for clip-cov loss + clip_cov_ub: 5.0 # for clip-cov loss + kl_cov_ratio: 0.0002 # for kl-cov loss + ppo_kl_coef: 0.1 # for kl-cov loss \ No newline at end of file diff --git a/recipe/entropy/entropy_ray_trainer.py b/recipe/entropy/entropy_ray_trainer.py new file mode 100644 index 000000000..49f2b5feb --- /dev/null +++ b/recipe/entropy/entropy_ray_trainer.py @@ -0,0 +1,663 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import uuid +from pprint import pprint +from copy import deepcopy +from collections import defaultdict +from tqdm import tqdm + +import numpy as np +import pandas as pd +import ray +import torch + +from verl import DataProto +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.utils.profiler import simple_timer + + +class RayEntropyTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize the drop count attributes to ensure they're always available + self.n_drop_easy = 0 + self.n_drop_hard = 0 + + def _create_priority_dataloader(self, epoch_idx, dynamic_filtering, enable_budget): + """ + Create the dataloader every time before the epoch starts. + """ + from torch.utils.data import SequentialSampler + from verl.trainer.main_ppo import create_rl_sampler + from verl.utils.dataset.rl_dataset import collate_fn + from torchdata.stateful_dataloader import StatefulDataLoader + + # Initialize columns for the first epoch + max_easy_ratio = self.config.data.get("max_easy_ratio", 0.1) + max_hard_ratio = self.config.data.get("max_hard_ratio", 0.2) + if epoch_idx == 0: + # Get the initial pass rate column name from config, with default fallback + initial_pass_rate_column = self.config.data.get("initial_pass_rate_column", "qwen3_30b_pass_rate") + self.train_dataset.dataframe["prev_pass_rate"] = self.train_dataset.dataframe[initial_pass_rate_column] + # use half of the max response length as the average length for the first epoch + self.train_dataset.dataframe["prev_passed_avg_length"] = self.config.data.get("max_response_length", 1024*28) * 3 / 4 + self.train_dataset.dataframe["prev_passed_max_length"] = self.config.data.get("max_response_length", 1024*28) * 3 / 4 + + + original_df = self.train_dataset.dataframe.copy() + + if dynamic_filtering: + # Separate data by pass rate + perfect_mask = original_df["prev_pass_rate"] == 1.0 + failed_mask = original_df["prev_pass_rate"] == 0.0 + medium_mask = (original_df["prev_pass_rate"] > 0.0) & (original_df["prev_pass_rate"] < 1.0) + + # Get indices for each category + medium_indices = original_df[medium_mask].index.tolist() + perfect_indices = original_df[perfect_mask].index.tolist() + failed_indices = original_df[failed_mask].index.tolist() + + # Keep all medium difficulty data + kept_indices = set(medium_indices) + n_medium = len(medium_indices) + + # Limit perfect examples to 1/10 of medium examples + self.n_drop_easy = 0 + if perfect_indices: + np.random.seed(42 + epoch_idx) + n_keep_perfect = int(max(1, min(n_medium * max_easy_ratio, len(perfect_indices)))) + if n_keep_perfect > 0: + kept_perfect = np.random.choice(perfect_indices, size=n_keep_perfect, replace=False) + kept_indices.update(kept_perfect) + self.n_drop_easy = len(perfect_indices) - n_keep_perfect + + # Limit failed examples to 1/5 of medium examples + self.n_drop_hard = 0 + if failed_indices: + np.random.seed(43 + epoch_idx) + n_keep_failed = int(max(1, min(n_medium * max_hard_ratio, len(failed_indices)))) + if n_keep_failed > 0: + kept_failed = np.random.choice(failed_indices, size=n_keep_failed, replace=False) + kept_indices.update(kept_failed) + self.n_drop_hard = len(failed_indices) - n_keep_failed + + filtered_df = original_df.loc[list(kept_indices)].reset_index(drop=True) + # Log filtering statistics + n_perfect_kept = len(set(perfect_indices) & kept_indices) + n_failed_kept = len(set(failed_indices) & kept_indices) + n_medium_kept = len(set(medium_indices) & kept_indices) + + print(f"Dataset filtering statistics for epoch {epoch_idx}:") + print(f"Original dataset size: {len(original_df)}") + print(f" - Perfect examples (pass_rate=1.0): {len(perfect_indices)} -> {n_perfect_kept} kept ({n_perfect_kept/max(1,len(perfect_indices))*100:.1f}%)") + print(f" - Failed examples (pass_rate=0.0): {len(failed_indices)} -> {n_failed_kept} kept ({n_failed_kept/max(1,len(failed_indices))*100:.1f}%)") + print(f" - Medium examples (0 {n_medium_kept} kept ({n_medium_kept/max(1,len(medium_indices))*100:.1f}%)") + print(f"Filtered dataset size: {len(filtered_df)}") + print(f"Total discarded data points: {len(original_df) - len(filtered_df)}") + print(f"Total percentage discarded: {100 * (len(original_df) - len(filtered_df)) / len(original_df):.2f}%") + else: + filtered_df = original_df.copy() + + def assign_length_budget(row, pass_rate_upper_bound, max_response_length): + prompt_pass_rate = row['prev_pass_rate'] + passed_prompt_avg_length = row['prev_passed_avg_length'] + passed_prompt_max_length = row['prev_passed_max_length'] + + # Get configurable multipliers with default values + perfect_pass_rate_multiplier = self.config.data.get("perfect_pass_rate_multiplier", 1.0) + high_pass_rate_multiplier = self.config.data.get("high_pass_rate_multiplier", 0.8) + + if prompt_pass_rate == 1.0: + new_length_budget = max(high_pass_rate_multiplier * passed_prompt_max_length, passed_prompt_avg_length) + elif prompt_pass_rate > pass_rate_upper_bound: + new_length_budget = max(high_pass_rate_multiplier * passed_prompt_max_length, passed_prompt_avg_length) + else: + new_length_budget = passed_prompt_max_length + (max_response_length - passed_prompt_max_length) * (1 - prompt_pass_rate) + + new_length_budget = max(new_length_budget, 4000) # Set minimum to 2000 + new_length_budget = min(new_length_budget, max_response_length) # Cap at max response length + + return int(new_length_budget) + + if enable_budget: + # Shared logic for computing per_prompt_length_budget and creating dataloader + max_response_length = self.config.data.get("max_response_length", 1024*28) + pass_rate_upper_bound = self.config.data.get("pass_rate_upper_bound", 1.0) + + + filtered_df["per_prompt_length_budget"] = filtered_df.apply( + lambda row: assign_length_budget(row, pass_rate_upper_bound, max_response_length), axis=1 + ) + filtered_df = filtered_df.sort_values(by="per_prompt_length_budget", ascending=True).reset_index(drop=True) + else: + filtered_df["per_prompt_length_budget"] = self.config.data.get("max_response_length", 1024*28) # Use fixed length budget + + # Sort by per_prompt_length_budget for more efficient rollout batching + filtered_df = filtered_df.sort_values(by="per_prompt_length_budget", ascending=True).reset_index(drop=True) + + # Create filtered dataset copy + train_dataset_copy = deepcopy(self.train_dataset) + train_dataset_copy.dataframe = filtered_df + + # Create dataloader + self.train_dataloader = StatefulDataLoader( + dataset=train_dataset_copy, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=self.config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=collate_fn, + sampler=SequentialSampler(data_source=filtered_df), + ) + + print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + return train_dataset_copy + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation on the training data for data filtering + # self._validate_training_data() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get('val_only', False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + + for batch_dict in self.train_dataloader: + metrics = {} + # Here the self.train_dataset is the whole dataset, while self.train_dataloader is a + # DataLoader that yields batches of data across GPUs. + # len(self.train_dataloader) * #GPUs = len(self.train_dataset) + # (bsz, seq_len) + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + # pop those keys for generation + if "multi_modal_inputs" in new_batch.non_tensor_batch.keys(): + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], + ) + else: + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], + ) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + is_last_step = self.global_steps >= self.total_training_steps + + with simple_timer("step", timing_raw): + # generate a batch + # with simple_timer("gen", timing_raw): + # gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + with simple_timer("gen", timing_raw): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with simple_timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(new_batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + new_batch.batch['reward_baselines'] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.union(gen_batch_output) + + with simple_timer("reward", timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + try: + reward_result = self.reward_fn(new_batch, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result["reward_extra_info"] + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = self.reward_fn(new_batch) + reward_extra_infos_dict = {} + + new_batch.batch['token_level_scores'] = reward_tensor + + print(f'{list(reward_extra_infos_dict.keys())=}') + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch['token_level_rewards'] = new_batch.batch['token_level_scores'] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, + # we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f'{num_prompt_in_batch=} < {prompt_bsz=}') + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f'{num_gen_batches=}. Keep generating...') + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + print( + f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. " + f"Collecting finished." + ) + batch = batch[:traj_bsz] + + # === Updating === + + batch.batch["response_mask"] = compute_response_mask(batch) + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. + # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch and not self.config.trainer.enable_budget: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + + # recompute old_log_probs + with simple_timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with simple_timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with simple_timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with simple_timer("adv", timing_raw): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + with _timer('pass_rate_append', timing_raw): + # compute the pass rate for the batch + temp_df = pd.DataFrame({ + "prompt_id": batch.non_tensor_batch["prompt_id"], + "prev_pass_rate": batch.non_tensor_batch["score"] + }) + pass_rate_df = temp_df.groupby("prompt_id", as_index=False)["prev_pass_rate"].mean().set_index('prompt_id')[['prev_pass_rate']] + + # compute the average response length for each prompt_id in the batch + # Only include lengths for successful rollouts (score == 1) + response_length = batch.batch["responses"].shape[-1] + response_mask = batch.batch["attention_mask"][:, -response_length:] + response_lengths = response_mask.sum(dim=-1).float().cpu().numpy() # actual lengths per response + scores = batch.non_tensor_batch["score"] + + # Filter for successful rollouts only (score == 1) + successful_mask = scores == 1 + if np.any(successful_mask): + successful_prompt_ids = batch.non_tensor_batch["prompt_id"][successful_mask] + successful_response_lengths = response_lengths[successful_mask] + + temp_length_df = pd.DataFrame({ + "prompt_id": successful_prompt_ids, + "response_length": successful_response_lengths + }) + avg_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].mean().set_index('prompt_id')[['response_length']] + avg_length_df.rename(columns={"response_length": "prev_passed_avg_length"}, inplace=True) + max_length_df = temp_length_df.groupby("prompt_id", as_index=False)["response_length"].max().set_index('prompt_id')[['response_length']] + max_length_df.rename(columns={"response_length": "prev_passed_max_length"}, inplace=True) + + # Update the dataframe with both pass rates and average lengths + self.train_dataset.dataframe = self.train_dataset.dataframe.set_index('prompt_id') + self.train_dataset.dataframe.update(pass_rate_df) + self.train_dataset.dataframe.update(avg_length_df) + self.train_dataset.dataframe.update(max_length_df) + self.train_dataset.dataframe = self.train_dataset.dataframe.reset_index() + else: + # If no successful rollouts in this batch, only update pass rates + self.train_dataset.dataframe = self.train_dataset.dataframe.set_index('prompt_id') + self.train_dataset.dataframe.update(pass_rate_df) + self.train_dataset.dataframe = self.train_dataset.dataframe.reset_index() + print("No successful rollouts (score=1) in this batch, skipping length update") + + # update critic + if self.use_critic: + with simple_timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with simple_timer("update_actor", timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with simple_timer("testing", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with simple_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + # Add per-prompt metrics for the current batch only + if hasattr(train_dataset, 'dataframe') and batch is not None: + # Get unique prompt_ids from the current batch + batch_prompt_ids = batch.non_tensor_batch["prompt_id"] + unique_prompt_ids = np.unique(batch_prompt_ids) + + # Filter dataframe to only include prompts from current batch + batch_df = train_dataset.dataframe[train_dataset.dataframe['prompt_id'].isin(unique_prompt_ids)] + + if len(batch_df) > 0: + metrics.update({ + "train/per_prompt_pass_rate_avg": batch_df["prev_pass_rate"].mean(), + "train/per_prompt_pass_rate_std": batch_df["prev_pass_rate"].std(), + "train/per_prompt_pass_rate_min": batch_df["prev_pass_rate"].min(), + "train/per_prompt_pass_rate_max": batch_df["prev_pass_rate"].max(), + "train/num_unique_prompts": len(unique_prompt_ids), + "train/per_prompt_length_budget_avg": batch_df["per_prompt_length_budget"].mean(), + "train/per_prompt_length_budget_std": batch_df["per_prompt_length_budget"].std(), + "train/per_prompt_length_budget_min": batch_df["per_prompt_length_budget"].min(), + "train/per_prompt_length_budget_max": batch_df["per_prompt_length_budget"].max(), + "train/prev_passed_max_length_avg": batch_df["prev_passed_max_length"].mean(), + "train/prev_passed_max_length_std": batch_df["prev_passed_max_length"].std(), + "train/prev_passed_max_length_min": batch_df["prev_passed_max_length"].min(), + "train/prev_passed_max_length_max": batch_df["prev_passed_max_length"].max(), + "train/prev_passed_avg_length_avg": batch_df["prev_passed_avg_length"].mean(), + "train/prev_passed_avg_length_std": batch_df["prev_passed_avg_length"].std(), + "train/prev_passed_avg_length_min": batch_df["prev_passed_avg_length"].min(), + "train/prev_passed_avg_length_max": batch_df["prev_passed_avg_length"].max() + }) + + metrics["train/num_gen_batches"] = num_gen_batches + metrics['train/num_prompts'] = len(train_dataset.dataframe) + metrics['train/perct_dropped_prompts'] = 100 * ( (len(self.train_dataset.dataframe) - len(train_dataset.dataframe)) / len(self.train_dataset.dataframe)) + metrics['train/n_drop_easy'] = self.n_drop_easy if self.n_drop_easy is not None else 0 + metrics['train/n_drop_hard'] = self.n_drop_hard if self.n_drop_hard is not None else 0 + metrics['train/epoch'] = epoch + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f'Final validation metrics: {last_val_metrics}') + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 + + def _save_dataset_state(self, local_global_step_folder): + """ + Save the current dataset state including updated pass rates and lengths. + This is crucial for resuming training with enable_budget feature. + """ + if not self.config.trainer.get('enable_budget', False): + return + + dataset_state_path = os.path.join(local_global_step_folder, 'dataset_state.pt') + + # Save the current dataset state + dataset_state = { + 'dataframe': self.train_dataset.dataframe.copy(), + 'n_drop_easy': getattr(self, 'n_drop_easy', 0), + 'n_drop_hard': getattr(self, 'n_drop_hard', 0), + } + + torch.save(dataset_state, dataset_state_path) + print(f"Saved dataset state to {dataset_state_path}") + print(f" - Dataset size: {len(self.train_dataset.dataframe)}") + print(f" - Pass rate range: {self.train_dataset.dataframe['prev_pass_rate'].min():.3f} - {self.train_dataset.dataframe['prev_pass_rate'].max():.3f}") + if 'prev_passed_max_length' in self.train_dataset.dataframe.columns: + print(f" - Max length range: {self.train_dataset.dataframe['prev_passed_max_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_max_length'].max():.1f}") + + def _load_dataset_state(self, global_step_folder): + """ + Load the dataset state including updated pass rates and lengths. + This restores the learned statistics from previous training. + """ + if not self.config.trainer.get('enable_budget', False): + return + + dataset_state_path = os.path.join(global_step_folder, 'dataset_state.pt') + + if os.path.exists(dataset_state_path): + print(f"Loading dataset state from {dataset_state_path}") + dataset_state = torch.load(dataset_state_path, weights_only=False) + + # Restore dataset with updated pass rates and lengths + self.train_dataset.dataframe = dataset_state['dataframe'] + self.n_drop_easy = dataset_state.get('n_drop_easy', 0) + self.n_drop_hard = dataset_state.get('n_drop_hard', 0) + + print(f"Restored dataset state:") + print(f" - Dataset size: {len(self.train_dataset.dataframe)}") + print(f" - Pass rate range: {self.train_dataset.dataframe['prev_pass_rate'].min():.3f} - {self.train_dataset.dataframe['prev_pass_rate'].max():.3f}") + if 'prev_passed_avg_length' in self.train_dataset.dataframe.columns: + print(f" - Avg length range: {self.train_dataset.dataframe['prev_passed_avg_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_avg_length'].max():.1f}") + if 'prev_passed_max_length' in self.train_dataset.dataframe.columns: + print(f" - Max length range: {self.train_dataset.dataframe['prev_passed_max_length'].min():.1f} - {self.train_dataset.dataframe['prev_passed_max_length'].max():.1f}") + else: + print(f"No dataset state found at {dataset_state_path}, starting with original dataset") + self.n_drop_easy = 0 + self.n_drop_hard = 0 + + def _save_checkpoint(self): + """ + Override to include dataset state saving for enable_budget feature. + """ + # Call parent method to save models and dataloader + super()._save_checkpoint() + + # Save additional dataset state for enable_budget + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + self._save_dataset_state(local_global_step_folder) + + def _load_checkpoint(self): + """ + Override to include dataset state loading for enable_budget feature. + """ + # Store original global_steps to detect if we loaded from checkpoint + original_global_steps = self.global_steps + + # Call parent method to load models and dataloader + result = super()._load_checkpoint() + + # If global_steps changed, we loaded from checkpoint + if self.global_steps > original_global_steps: + checkpoint_folder = self.config.trainer.default_local_dir + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + + # Find the same checkpoint folder that was loaded + from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path + global_step_folder = find_latest_ckpt_path(checkpoint_folder) + + if global_step_folder is not None: + self._load_dataset_state(global_step_folder) + + return result \ No newline at end of file diff --git a/recipe/entropy/main_entropy.py b/recipe/entropy/main_entropy.py new file mode 100644 index 000000000..a8bb0cb6a --- /dev/null +++ b/recipe/entropy/main_entropy.py @@ -0,0 +1,245 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import hydra +import ray + +from .entropy_ray_trainer import RayEntropyTrainer +from .reward import load_reward_manager + + +@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "WANDB_API_KEY": "YOUR_WANDB_API_KEY", + } + }, + num_cpus=config.ray_init.num_cpus, + ) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +def merge_dict(a: dict, b: dict) -> dict: + """Return a new dict that has `a` updated with `b` (b wins on conflicts). + + Example:: + + >>> d1 = {"x": 1, "y": 2} + >>> d2 = {"y": 20, "z": 3} + >>> new_dict = merge_dict(d1, d2) + >>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3} + >>> print(d1) # {"x": 1, "y": 2} (unchanged) + >>> print(d2) # {"y": 20, "z": 3} (unchanged) + """ + return a | b + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + print(f"{config.actor_rollout_ref.model.path}") + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # use reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_kwargs = { + "max_resp_len": config.data.max_response_length, + "overlong_buffer_cfg": config.reward_model.overlong_buffer, + } + cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {}) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs) + ) + val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_sampler = create_rl_sampler(config.data, train_dataset) + trainer = RayEntropyTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + trainer.init_workers() + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor): + """Create a dataset. + + Arguments: + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + from verl.utils.import_utils import load_extern_type + + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' " + f"must inherit from torch.utils.data.Dataset" + ) + else: + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + # use sampler for better ckpt resume + if data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/recipe/entropy/reward.py b/recipe/entropy/reward.py new file mode 100644 index 000000000..36b8b65a4 --- /dev/null +++ b/recipe/entropy/reward.py @@ -0,0 +1,86 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +from functools import partial + +import ray + +from verl import DataProto +from verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn + +from .reward_score import _default_compute_score + + +def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): + """ + Load and initialize a reward manager based on the configuration. + + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. + + Returns: + An instance of the specified reward manager class. + """ + from verl.workers.reward_manager import get_reward_manager_cls + + # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: + # naive: NaiveRewardManager + # prime: PrimeRewardManager + # batch: BatchRewardManager + # dapo: DAPORewardManager + # Note(haibin.lin): For custom reward managers, please make sure they are imported and + # registered via `verl.workers.reward_manager.register` + # By default reward_manager is set to naive (NaiveRewardManager) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + + # Try to get a custom reward function based on the configuration + compute_score = get_custom_reward_fn(config) + final_compute_score = compute_score + + if compute_score is None: + sandbox_config = config.reward_model.get("sandbox_fusion") + sandbox_url = sandbox_config.get("url") if sandbox_config else None + if sandbox_url: + sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox + _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + final_compute_score = partial( + _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore + ) + else: + final_compute_score = _default_compute_score + + # Instantiate and return the reward manager with the specified parameters + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=final_compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config, tokenizer): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + return compute_reward(data, reward_fn) diff --git a/recipe/entropy/reward_score/__init__.py b/recipe/entropy/reward_score/__init__.py new file mode 100644 index 000000000..7224bf3c3 --- /dev/null +++ b/recipe/entropy/reward_score/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from . import gsm8k, math, prime_math, prime_code + +import traceback + +from . import entropy_math + + +def _default_compute_score( + data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None +): + try: + res = entropy_math.compute_score(solution_str, str(ground_truth)) + # print(f"data_source: {data_source}") + # raise NotImplementedError(f"Reward function is not implemented for {data_source=}") + + if isinstance(res, dict): + return res + elif isinstance(res, int | float | bool): + return float(res) + else: + return float(res[0]) + except Exception as e: + print(f"[ERROR] Error in process_completion for task : {str(e)}") + traceback.print_exc() # 打印完整堆栈 + raise # 重新抛出异常以便上层捕获 diff --git a/recipe/entropy/reward_score/entropy_math/__init__.py b/recipe/entropy/reward_score/entropy_math/__init__.py new file mode 100644 index 000000000..1b2ba647d --- /dev/null +++ b/recipe/entropy/reward_score/entropy_math/__init__.py @@ -0,0 +1,1062 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except Exception in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides a math answer grading function with high recall. +Based on HF math_verify, verl, open reasoner zero, etc. +""" + +import os +import re +import signal +from itertools import islice, zip_longest +from math import isclose +from typing import Optional + +import sympy +from latex2sympy2_extended import latex2sympy +from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify +from pylatexenc import latex2text +from sympy import N, simplify +from sympy.parsing import sympy_parser +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr + +""" +This code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py). +""" + + +def timeout_ours(timeout_seconds: int = 8): + if os.name == "posix": + import signal + + def decorator(func): + def handler(signum, frame): + raise TimeoutError("Operation timed out!") + + def wrapper(*args, **kwargs): + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout_seconds) + + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + return wrapper + + return decorator + else: + raise NotImplementedError(f"Unsupported OS: {os.name}") + + +# Dan Hendrycks' code +def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]: + if answer is None: + return None + answer = answer.strip() + try: + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", answer) + if m is not None: + answer = m.group("text").strip() + return _strip_string(answer) + except Exception: + return answer + + +# units mainly from MathQA +unit_texts = [ + "east", + "degree", + "mph", + "kmph", + "ft", + "m sqaure", + " m east", + "sq m", + "deg", + "mile", + "q .", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "gm", + "p . m", + "lb", + "tile", + "per", + "dm", + "lt", + "gain", + "ab", + "way", + "west", + "a .", + "b .", + "c .", + "d .", + "e .", + "f .", + "g .", + "h .", + "t", + "a", + "h", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "percent", + "by", + "gal", + "kmh", + "c", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "l", + "toy", + "coin", + "sq . m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "feet", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "s", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . m", + "cm 2", + "rs .", + "rs", + "kg", + "g", + "month", + "km", + "m", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "d", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", +] + +unit_texts.extend([t + "s" for t in unit_texts]) + + +def _strip_string(string): + def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except Exception: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except Exception: + return string + + def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + # linebreaks + string = string.replace("\n", "") + # print(string) + + # remove inverse spaces + string = string.replace("\\!", "") + # print(string) + + # replace \\ with \ + string = string.replace("\\\\", "\\") + # print(string) + + # matrix + string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) + string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) + string = string.replace("bmatrix", "pmatrix") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") + # print(string) + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + # print(string) + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + # Remove unit: texts + for _ in range(2): + for unit_text in unit_texts: + # use regex, the prefix should be either the start of the string or a non-alphanumeric character + # the suffix should be either the end of the string or a non-alphanumeric character + _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) + if _string != "": + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). + # Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18. + """ + # final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer + + +def repeatness(s: str): + def ranks(seq): + index = {v: i for i, v in enumerate(sorted(set(seq)))} + return [index[v] for v in seq] + + def suffixArray(s): + line = ranks(s) + n, k, ans, sa = len(s), 1, line, [0] * len(s) + while k < n - 1: + line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1))) + ans, k = line, k << 1 + for i, k in enumerate(ans): + sa[k] = i + return ans, sa + + def lcp(arr, suffixArr, inv_suff): + n, ans, k = len(arr), [0] * len(arr), 0 + + for i in range(n): + if inv_suff[i] == n - 1: + k = 0 + continue + + j = suffixArr[inv_suff[i] + 1] + while i + k < n and j + k < n and arr[i + k] == arr[j + k]: + k += 1 + + ans[inv_suff[i]] = k + if k > 0: + k -= 1 + + return ans + + arr = [ord(i) for i in s] + n = len(arr) + if n <= 1: + return 0 + c, sa = suffixArray(arr) + cnt = sum(lcp(arr, sa, c)) + + return (cnt * 2 / (n * (n + 1))) > 0.2 + + +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def latex_eval(latex): + sym = parse_latex(latex) + val = sym.evalf() + return sym, val + + +def numeric_equal(prediction: float, reference: float): + # Note that relative tolerance has significant impact + # on the result of the synthesized GSM-Hard dataset + # if reference.is_integer(): + # return isclose(reference, round(prediction), abs_tol=1e-4) + # else: + # prediction = round(prediction, len(str(reference).split(".")[-1])) + return isclose(reference, prediction, rel_tol=1e-4) + + +@timeout_ours(timeout_seconds=5) +def symbolic_equal(a, b): + def _parse(s): + for f in [parse_latex, parse_expr, latex2sympy]: + try: + return f(s.replace("\\\\", "\\")) + except Exception: + try: + return f(s) + except Exception: + pass + return s + + a = _parse(a) + b = _parse(b) + + # direct equal + try: + if str(a) == str(b) or a == b: + return True + except Exception: + pass + + # simplify equal + try: + if a.equals(b) or simplify(a - b) == 0: + return True + except Exception: + pass + + # equation equal + try: + if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): + return True + except Exception: + pass + + try: + if numeric_equal(float(N(a)), float(N(b))): + return True + except Exception: + pass + + # matrix + try: + # if a and b are matrix + if a.shape == b.shape: + _a = a.applyfunc(lambda x: round(x, 3)) + _b = b.applyfunc(lambda x: round(x, 3)) + if _a.equals(_b): + return True + except Exception: + pass + + return False + + +def _is_latex_equal(str1, str2): + try: + sym1, val1 = latex_eval(str1) + sym2, val2 = latex_eval(str2) + if sym1 == sym2 or val1 == val2: + return True + else: + raise ValueError + except Exception: # noqa + try: + norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2) + sym1, val1 = latex_eval(norm1) + sym2, val2 = latex_eval(norm2) + if sym1 == sym2 or val1 == val2: + return True + except Exception: # noqa + return norm1 == norm2 + return False + + +def is_latex_equal(given_answer: str, ground_truth: str) -> bool: + try: + with timeout(1): + try: + if (len(given_answer) > 128 and repeatness(given_answer)) or ( + len(ground_truth) > 128 and repeatness(ground_truth) + ): + return False + # First conduct normalized string matching. + ground_truth_normalized = _normalize(ground_truth) + given_normalized = _normalize(given_answer) + if ground_truth_normalized is None: + return False + if ground_truth_normalized == given_normalized: + return True + + # Next call math verify. + given_answer.replace("\n", "") + ground_truth.replace("\n", "") + if "$" not in given_answer: + given_answer = f"${given_answer}$" + if "$" not in ground_truth: + ground_truth = f"${ground_truth}$" + return verify( + parse( + ground_truth, + extraction_config=( + LatexExtractionConfig(boxed_match_priority=0), + ExprExtractionConfig(), + ), + fallback_mode="no_fallback", + extraction_mode=["first_match"], + parsing_timeout=1, + ), + parse( + given_answer, + extraction_config=( + LatexExtractionConfig(boxed_match_priority=0), + ExprExtractionConfig(), + ), + fallback_mode="no_fallback", + extraction_mode=["first_match"], + parsing_timeout=1, + ), + timeout_seconds=1, + ) + # or symbolic_equal(ground_truth, given_answer) + except Exception: + return False + except TimeoutError: + return False + + +def is_value_equal(given_answer: str, ground_truth: str) -> bool: + assert ground_truth is not None + ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) + given_answer_normalized_mathd = mathd_normalize_answer(given_answer) + + str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd + try: + number_equal = float(ground_truth_normalized_mathd) == float(given_answer_normalized_mathd) + return str_equal or number_equal + except Exception: + return str_equal + + +# sympy might hang -- we don't care about trying to be lenient in these cases +BAD_SUBSTRINGS = ["^{", "^("] +BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +TUPLE_CHARS = "()[]" + + +def _sympy_parse(expr: str): + """Parses an expression with sympy.""" + py_expr = expr.replace("^", "**") + return sympy_parser.parse_expr( + py_expr, + transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + ) + + +def _parse_latex(expr: str) -> str: + """Attempts to parse latex to an expression sympy can read.""" + expr = expr.replace("\\tfrac", "\\frac") + expr = expr.replace("\\dfrac", "\\frac") + expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. + expr = latex2text.LatexNodes2Text().latex_to_text(expr) + + # Replace the specific characters that this parser uses. + expr = expr.replace("√", "sqrt") + expr = expr.replace("π", "pi") + expr = expr.replace("∞", "inf") + expr = expr.replace("∪", "U") + expr = expr.replace("·", "*") + expr = expr.replace("×", "*") + + return expr.strip() + + +def _is_float(num: str) -> bool: + try: + float(num) + return True + except ValueError: + return False + + +def _is_int(x: float) -> bool: + try: + return abs(x - int(round(x))) <= 1e-7 + except Exception: + return False + + +def _is_frac(expr: str) -> bool: + return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) + + +def _str_is_int(x: str) -> bool: + try: + x = _strip_properly_formatted_commas(x) + x = float(x) + return abs(x - int(round(x))) <= 1e-7 + except Exception: + return False + + +def _str_to_int(x: str) -> bool: + x = x.replace(",", "") + x = float(x) + return int(x) + + +def _inject_implicit_mixed_number(step: str): + """ + Automatically make a mixed number evalable + e.g. 7 3/4 => 7+3/4 + """ + p1 = re.compile("([0-9]) +([0-9])") + step = p1.sub("\\1+\\2", step) ## implicit mults + return step + + +def _strip_properly_formatted_commas(expr: str): + # We want to be careful because we don't want to strip tuple commas + p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + while True: + next_expr = p1.sub("\\1\\3\\4", expr) + if next_expr == expr: + break + expr = next_expr + return next_expr + + +def _normalize(expr: str) -> str: + """Normalize answer expressions.""" + if expr is None: + return None + + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", expr) + if m is not None: + expr = m.group("text") + + expr = expr.replace("\\%", "%") + expr = expr.replace("\\$", "$") + expr = expr.replace("$", "") + expr = expr.replace("%", "") + expr = expr.replace(" or ", " , ") + expr = expr.replace(" and ", " , ") + + expr = expr.replace("million", "*10^6") + expr = expr.replace("billion", "*10^9") + expr = expr.replace("trillion", "*10^12") + + for unit in [ + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + ]: + expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) + expr = re.sub("\^ *\\\\circ", "", expr) + + if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": + expr = expr[1:-1] + + expr = re.sub(",\\\\! *", "", expr) + if _is_float(expr) and _is_int(float(expr)): + expr = str(int(round(float(expr)))) + if "\\" in expr: + try: + expr = _parse_latex(expr) + except Exception: + pass + + # edge case with mixed numbers and negative signs + expr = re.sub("- *", "-", expr) + + expr = _inject_implicit_mixed_number(expr) + expr = expr.replace(" ", "") + + # if we somehow still have latex braces here, just drop them + expr = expr.replace("{", "") + expr = expr.replace("}", "") + + # don't be case sensitive for text answers + expr = expr.lower() + + if _str_is_int(expr): + expr = str(_str_to_int(expr)) + + return expr + + +def count_unknown_letters_in_expr(expr: str): + expr = expr.replace("sqrt", "") + expr = expr.replace("frac", "") + letters_in_expr = set([x for x in expr if x.isalpha()]) + return len(letters_in_expr) + + +def should_allow_eval(expr: str): + # we don't want to try parsing unknown text or functions of more than two variables + if count_unknown_letters_in_expr(expr) > 2: + return False + + for bad_string in BAD_SUBSTRINGS: + if bad_string in expr: + return False + + for bad_regex in BAD_REGEXES: + if re.search(bad_regex, expr) is not None: + return False + + return True + + +@timeout_ours(timeout_seconds=5) +def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): + are_equal = False + try: + expr = f"({ground_truth_normalized})-({given_normalized})" + if should_allow_eval(expr): + sympy_diff = _sympy_parse(expr) + simplified = sympy.simplify(sympy_diff) + if simplified == 0: + are_equal = True + except Exception: + pass + return are_equal + + +def split_tuple(expr: str): + """ + Split the elements in a tuple/interval, while handling well-formatted commas in large numbers + """ + expr = _strip_properly_formatted_commas(expr) + if len(expr) == 0: + return [] + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): + elems = [elem.strip() for elem in expr[1:-1].split(",")] + else: + elems = [expr] + return elems + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def remove_boxed(s): + left = "\\boxed{" + try: + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + except Exception: + return None + + +def extract_boxed_answer(solution: str) -> str: + """Extract the answer from inside a LaTeX \\boxed{} command""" + solution = last_boxed_only_string(solution) + solution = remove_boxed(solution) + return solution + + +def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: + ground_truth_normalized = _normalize(ground_truth) + given_normalized = _normalize(given_answer) + + if ground_truth_normalized is None: + return False + + if ground_truth_normalized == given_normalized: + return True + + if len(given_normalized) == 0: + return False + + ground_truth_elems = split_tuple(ground_truth_normalized) + given_elems = split_tuple(given_normalized) + + if len(ground_truth_elems) > 1 and ( + ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1] + ): + is_correct = False + elif len(ground_truth_elems) != len(given_elems): + is_correct = False + else: + for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): + if _is_frac(ground_truth_elem) and _is_frac(given_elem): + # if fractions aren't reduced, then shouldn't be marked as correct + # so, we don't want to allow sympy.simplify in this case + is_correct = ground_truth_elem == given_elem + elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): + # if the ground truth answer is an integer, we require the given answer to be a strict match + # (no sympy.simplify) + is_correct = False + else: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + if not is_correct: + break + + return is_correct + + +def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool: + ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) + given_answer_normalized_mathd = mathd_normalize_answer(given_answer) + + # be at least as lenient as mathd + if ground_truth_normalized_mathd == given_answer_normalized_mathd: + return True + return False + + +def extract_answer(passage: str) -> str: + if "\\boxed" in passage: + return extract_boxed_answer(passage) + return None + + +def grade(model_answer: str, gt_answer: str, fast: bool = True): + if "\\boxed" in gt_answer: + gt_answer = extract_answer(gt_answer) + correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer) + if not fast: + # This mode further uses math_verify to recall originally false positives. + # Will be a bit slower, and sensitive to bad inputs. + correct = correct or is_latex_equal( + model_answer, + gt_answer, + ) + return correct + + +def compute_score(model_response, gt_answer, fast=False): + model_answer = extract_answer(model_response) + if model_answer is None: + return { + "score": 0.0, + "format_score": 0.0, + "acc": False, + "extracted_gt": gt_answer, + # "extracted_pred": None, + } + # return 0.0, 0.0 # Cannot even parse anything. + is_correct = False + if isinstance(gt_answer, float) or isinstance(gt_answer, int): + gt_answer = str(gt_answer) + if isinstance(gt_answer, str): + is_correct = grade(model_answer, gt_answer, fast) + elif isinstance(gt_answer, list): + is_correct = False + for gt in gt_answer: + is_correct |= grade(model_answer, gt, fast) + if is_correct: + return { + "score": 1.0, + "format_score": 1.0, + "acc": True, + "extracted_gt": gt_answer, + # "extracted_pred": None, + } + else: + return { + "score": 0.0, + "format_score": 1.0, + "acc": False, + "extracted_gt": gt_answer, + # "extracted_pred": None, + } diff --git a/recipe/entropy/reward_score/entropy_math/grader.py b/recipe/entropy/reward_score/entropy_math/grader.py new file mode 100644 index 000000000..02507e359 --- /dev/null +++ b/recipe/entropy/reward_score/entropy_math/grader.py @@ -0,0 +1,384 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Microsoft Corporation. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE + +# Copyright (c) 2023 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) 2021 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: +- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py +- https://github.com/microsoft/ProphetNet/tree/master/CRITIC +- https://github.com/openai/prm800k +""" + +import contextlib +import math +import re +from math import isclose + +# sympy related +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr + +# verl related +from verl.utils.py_functional import timeout_limit + + +def is_digit(s): + try: + if "{,}" in str(s): + num = float(str(s).replace("{,}", "")) + return True, num + + num = float(str(s).replace(",", "")) + return True, num + except ValueError: + return False, None + + +def normalize(answer, pi) -> str: + # checking if answer is $ and removing $ in that case to compare + if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): + return answer[1:] + + # checking if answer is % or \\% and removing % + if isinstance(answer, str) and ( + bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + ): + return answer.replace("\\%", "").replace("%", "") + + # handle base + answer = handle_base(answer) + + # handle pi + answer = handle_pi(answer, pi) + + return answer + + +def handle_base(x) -> str: + if isinstance(x, str) and "_" in x: + # Due to base + x = x.split("_")[0] + x = float(x) + return int(x) + return x + + +def handle_pi(string, pi): + if isinstance(string, str) and "\pi" in string: + # Find the first occurrence of "\pi" + idx = string.find("\pi") + + # Iterate over the string and find all occurrences of "\pi" with a valid previous character + while idx != -1: + if idx > 0 and string[idx - 1].isdigit(): + # Replace "\pi" with "*math.pi" if the previous character is a digit + string = string[:idx] + f"*{pi}" + string[idx + 3 :] + else: + # Replace "\pi" with "1*math.pi" if the previous character is not a digit + string = string[:idx] + f"1*{pi}" + string[idx + 3 :] + + # Find the next occurrence of "\pi" + idx = string.find("\pi", idx + 1) + + # Evaluate the expression using eval() function + with contextlib.suppress(Exception): + string = eval(string) + + return string + + +def math_equal( + prediction: bool | float | str, + reference: float | str, + include_percentage: bool = True, + tolerance: float = 1e-4, + timeout: float = 10.0, + pi: float = math.pi, +) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + + prediction = normalize(prediction, pi) + reference = normalize(reference, pi) + + if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases + prediction = prediction[:1000] + + # 0. string comparison + if isinstance(prediction, str) and isinstance(reference, str): + if prediction.strip().lower() == reference.strip().lower(): + return True + if prediction.replace(" ", "") == reference.replace(" ", ""): + return True + + try: # 1. numerical equal + if is_digit(prediction)[0] and is_digit(reference)[0]: + prediction = is_digit(prediction)[1] + reference = is_digit(reference)[1] + # number questions + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if isclose(item, prediction, rel_tol=tolerance): + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + ## deal with [], (), {} + prediction = format_intervals(prediction) + + pred_str, ref_str = prediction, reference + if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( + prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + ): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str == ref_str: + return True + + ## [a, b] vs. [c, d], return a==c and b==d + if ( + prediction + and reference + and prediction[0] in "([" + and prediction[-1] in ")]" + and prediction[0] == reference[0] + and prediction[-1] == reference[-1] + ): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) + ] + ): + return True + + if "," in prediction and "," in reference: + pred_parts = [item.strip() for item in prediction.split(",")] + ref_parts = [item.strip() for item in reference.split(",")] + + if len(pred_parts) == len(ref_parts): + return bool( + all( + [ + math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + for i in range(len(pred_parts)) + ] + ) + ) + + # if we have point == tuple of values + if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": + pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) + ] + ): + return True + + # if reference is a matrix + if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"): + try: + pred_matrix = parse_expr(prediction) + ref_matrix_items = reference.split()[1:-1:2] + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True) + ] + ): + return True + except Exception: + pass + elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"): + if isinstance(eval(prediction), list): + try: + pred_matrix = eval(prediction) + # ref_matrix_items = reference.split()[1:-1:2] + ref_matrix_items = ( + reference.lstrip("\\begin{pmatrix}") # noqa: B005 + .lstrip("\begin{pmatrix}") + .rstrip("\\end{pmatrix}") + .rstrip("\end{pmatrix}") + ) # noqa: B005 + ref_matrix_items = ref_matrix_items.split("\\") + ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True) + ] + ): + return True + except Exception: + pass + + return symbolic_equal(prediction, reference, tolerance, timeout) + + +def symbolic_equal(a, b, tolerance, timeout=10.0): + def _parse(s): + for f in [parse_expr, parse_latex]: + try: + with timeout_limit(seconds=timeout): + return f(s) + except TimeoutError: + print(f"Parsing timed out for {s}") + continue + except Exception: + continue + return s + + a = _parse(a) + b = _parse(b) + + try: + with timeout_limit(seconds=timeout): + if simplify(a - b) == 0: + return True + except TimeoutError: + print(f"Simplification timed out for {a} - {b}") + pass + except Exception: + pass + + try: + with timeout_limit(seconds=timeout): + if isclose(N(a), N(b), rel_tol=tolerance): + return True + except TimeoutError: + print(f"Numerical evaluation timed out for {a}, {b}") + pass + except Exception: + pass + return False + + +def format_intervals(prediction): + patterns = { + "Interval(": r"^Interval\((.*)\)$", + "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", + "Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$", + "Interval.open(": r"^Interval\.open\((.*)\)$", + } + + for key, pattern in patterns.items(): + match = re.match(pattern, prediction) + if match: + inner_content = match.group(1) + + if key == "Interval(": # Intarval(a, b) == [a, b] + return f"[{inner_content}]" + elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b) + return f"[{inner_content})" + elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b] + return f"({inner_content}]" + elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) + return f"({inner_content})" + + return prediction diff --git a/recipe/entropy/reward_score/entropy_math/math_normalize.py b/recipe/entropy/reward_score/entropy_math/math_normalize.py new file mode 100644 index 000000000..74d94cc41 --- /dev/null +++ b/recipe/entropy/reward_score/entropy_math/math_normalize.py @@ -0,0 +1,192 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2021 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This logic is largely copied from the Hendrycks' MATH release (math_equivalence). + +From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py +""" + +import re +from typing import Optional + + +def normalize_answer(answer: Optional[str]) -> Optional[str]: + if answer is None: + return None + answer = answer.strip() + try: + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", answer) + if m is not None: + answer = m.group("text").strip() + return _strip_string(answer) + except: # noqa: E722 + return answer + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: # noqa: E722 + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: # noqa: E722 + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). + # Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string diff --git a/recipe/genrm_remote/README.md b/recipe/genrm_remote/README.md new file mode 100644 index 000000000..1a800fd88 --- /dev/null +++ b/recipe/genrm_remote/README.md @@ -0,0 +1,39 @@ +# Generative Reward Model + +## Scripts + +### Step 1: Launch a vLLM Server (Optional) + +Deploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service. + +```bash +vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo +``` + +### Step 2: Perform RL using GenRM + +```bash +bash recipe/api-genrm/run_genrm_remote.sh +``` + +The implementation works by passing a customized reward function (see `reward_function.py`) + +For convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first. + +## Advanced: Customizing Your GenRM + +You can use sglang server with data parallel for faster inference: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 +``` + +Note that you should modify the `BASE_URL` in `reward_function.py` to match your SGLang Server address. + +You can also create your own customized GenRM by implementing a custom reward function. Here are some tips for customizing your own GenRM based on `reward_function.py`: + +- Design appropriate prompts for your GenRM +- Convert GenRM responses into RL rewards +- ... + +Since these aspects are highly flexible, we only provide a demo implementation. The actual design and implementation of GenRM is left to the user's discretion. diff --git a/recipe/genrm_remote/reward_function.py b/recipe/genrm_remote/reward_function.py new file mode 100644 index 000000000..b2d3fbc2f --- /dev/null +++ b/recipe/genrm_remote/reward_function.py @@ -0,0 +1,110 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor +from time import sleep + +import requests + +from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed + +BASE_URL = "http://localhost:30000" +API_KEY = "EMPTY" +MAX_RETRIES = 3 +BASE_DELAY = 2 +MAX_WORKERS = 32 +MODEL_NAME = "genrm-demo" +GENRM_PROMPT_TEMPLATE = """ +The following is a math problem and an AI solution: + +[Math Problem] + +{problem} + +[AI Solution] + +{solution} + +Your task is to review and critique the solution step by step, and output whether the AI solution is correct. + +Please put your final answer (i.e., 'True' or 'False') in \\boxed{{}}. +""".strip() + + +def get_response(problem, solution_str, ground_truth): + prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str) + messages = [{"role": "user", "content": prompt}] + for attempt in range(MAX_RETRIES): + try: + headers = {"Content-Type": "application/json"} + chat_url = f"{BASE_URL}/v1/chat/completions" + data = {"model": MODEL_NAME, "messages": messages} + output = requests.post(chat_url, headers=headers, json=data, timeout=30) + response = output.json()["choices"][0]["message"]["content"] + return response + except Exception as e: + if attempt < MAX_RETRIES - 1: + print("Exception: ", repr(e)) + delay = BASE_DELAY * (2**attempt) + print(f"Retrying in {delay} seconds...") + sleep(delay) + else: + print(f"Failed after {MAX_RETRIES} attempts. Error: {e}") + + raise ConnectionRefusedError(f"Failed to run the model for {prompt}!") + + +def compute_reward(response): + reward_score = 0.0 + try: + boxed_result = last_boxed_only_string(response) + if boxed_result is not None: + result = remove_boxed(boxed_result) + reward_score = float(result == "True") + except Exception as e: + print(e) + return reward_score + + +def compute_score(data_source, solution_str, ground_truth, extra_info): + split = extra_info["split"] + from verl.utils.reward_score import default_compute_score + + func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info) + + if split == "test": + return func_rm_score + else: + problem = extra_info["question"] + response = get_response(problem, solution_str, ground_truth) + if response is not None: + reward_score = compute_reward(response) + else: + reward_score = 0.0 + + return reward_score + + +def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos): + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + futures = [] + for data_source, solution_str, ground_truth, extra_info in zip( + data_sources, solution_strs, ground_truths, extra_infos, strict=True + ): + future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info) + futures.append(future) + + results = [future.result() for future in futures] + + return results diff --git a/recipe/genrm_remote/run_genrm_remote.sh b/recipe/genrm_remote/run_genrm_remote.sh new file mode 100644 index 000000000..6656dc8a7 --- /dev/null +++ b/recipe/genrm_remote/run_genrm_remote.sh @@ -0,0 +1,45 @@ +# vllm server +# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo + +# sglang server +# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 + +set -x + +CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=${HOME}/data/gsm8k/train.parquet \ + data.val_files=${HOME}/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=8 \ + algorithm.use_kl_in_reward=False \ + reward_model.reward_manager=batch \ + custom_reward_function.path=recipe/genrm_remote/reward_function.py \ + custom_reward_function.name=compute_score_batch \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_func_rm_example_gsm8k' \ + trainer.experiment_name='qwen2_5_3b_gen_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.val_before_train=True \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=10 \ + trainer.resume_mode='disable' diff --git a/tests/e2e/__init__.py b/recipe/langgraph_agent/__init__.py similarity index 100% rename from tests/e2e/__init__.py rename to recipe/langgraph_agent/__init__.py diff --git a/recipe/langgraph_agent/chat_model.py b/recipe/langgraph_agent/chat_model.py new file mode 100644 index 000000000..f41f6ac37 --- /dev/null +++ b/recipe/langgraph_agent/chat_model.py @@ -0,0 +1,357 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Ref: https://python.langchain.com/docs/how_to/custom_chat_model/ +""" + +import asyncio +import json +import logging +import os +import uuid +from typing import Any, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.base import LanguageModelInput +from langchain_core.messages import ( + AIMessage, + BaseMessage, + convert_to_openai_messages, +) +from langchain_core.messages.tool import InvalidToolCall, ToolCall +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.tools import StructuredTool +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import Field + +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager +from verl.experimental.agent_loop.tool_parser import ToolParser + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MaxTokenExceededError(Exception): + """Indicate that history chat messages + tool message exceeds LLM max_tokens.""" + + pass + + +class ChatModel(BaseChatModel): + model_name: str = Field(alias="model") + """The name of the model""" + + client: AsyncLLMServerManager + """AsyncLLM server manager""" + + tokenizer: Any + """Tokenizer for the model""" + + max_tokens: int + """Max tokens to generate""" + + tool_parser: str = "hermes" + """Tool parser for the model""" + + max_parallel_calls: int = 1 + """Max parallel tool calls""" + + temperature: float = 1.0 + """Temperature for sampling""" + + top_p: float = 1.0 + """Top p for sampling""" + + repetition_penalty: float = 1.0 + """Repetition penalty for sampling""" + + def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + + Returns: + A Runnable that returns a message. + """ + formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] + + # used to remove system prompt prefix when encoding tool response + system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) + kwargs["system_prompt"] = system_prompt + + return self.bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: dict | type, + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, dict | BaseChatModel]: + """Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/""" + raise NotImplementedError + + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> ChatResult: + raise NotImplementedError + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> ChatResult: + """Asynchronously generate chat completion message. + + Args: + messages (list[BaseMessage]): List of list of messages. + stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. Defaults to None. + + Returns: + ChatResult: Chat result. + """ + request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs) + + sampling_params = { + "temperature": self.temperature, + "top_p": self.top_p, + "repetition_penalty": self.repetition_penalty, + } + if "sampling_params" in kwargs: + sampling_params.update(kwargs["sampling_params"]) + + response_ids = await self.client.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + + message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model.""" + return self.model_name + + async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]: + """Preprocess messages for chat completion. + + To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out + instead of messages list. + + But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory + (prompt_ids, response_mask) in lastest AIMessage.response_metadata. + + 1. Encode ToolMessage to token ids. + 2. Retrieve trajectory (prompt_ids, response_mask) from lastest AIMessage.response_metadata. + 3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask. + + Ref: https://python.langchain.com/docs/concepts/chat_history/ + + Args: + messages (list[BaseMessage]): List of messages. + + Returns: + tuple[str, list[int], list[int]]: Request id, prompt ids, response mask. + """ + # messages: [system], human, ai, human|tool, ai, human|tool, ... + assert messages[-1].type in ["human", "tool"], ( + f"Last message must be human or tool, but got {messages[-1].type}" + ) + loop = asyncio.get_running_loop() + + # Case 1: initial chat completion: [system], human + if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"): + prompt_ids = await loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + convert_to_openai_messages(messages), + tools=kwargs.get("tools"), + add_generation_prompt=True, + tokenize=True, + ), + ) + return str(uuid.uuid4()), prompt_ids, [] + + # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ... + for i in range(len(messages) - 1, -1, -1): + if messages[i].type == "ai": + break + assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata" + assert "response_mask" in messages[i].response_metadata, ( + "Last message must have response_mask in response_metadata" + ) + + # encode tool response + tool_responses = convert_to_openai_messages(messages[i + 1 :]) + tool_response_ids = await loop.run_in_executor( + None, + lambda messages=tool_responses: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ), + ) + tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :] + + # stop generation if response length exceeds max response length + if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens: + raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded") + + # append tool response to prompt + request_id = messages[i].response_metadata.pop("request_id") + prompt_ids = messages[i].response_metadata.pop("prompt_ids") + response_mask = messages[i].response_metadata.pop("response_mask") + prompt_ids += tool_response_ids + response_mask += [0] * len(tool_response_ids) + + return request_id, prompt_ids, response_mask + + async def _postprocess( + self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any + ) -> AIMessage: + """Postprocess response_ids when chat completion is done. + + 1. Decode response_ids, parse tool calls to AIMessage. + 2. Append response_ids to prompt_ids, and append 1 to response_mask. + 3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata. + + Args: + request_id (str): Unique request id. + prompt_ids (list[int]): Input prompt token ids in this chat completion. + response_mask (list[int]): Response mask before this chat completion. + response_ids (list[int]): LLM generated token ids in this chat completion. + + Returns: + AIMessage: Postprocessed message. + """ + prompt_ids += response_ids + response_mask += [1] * len(response_ids) + + tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer) + content, function_calls = await tool_parser.extract_tool_calls(response_ids) + + tool_calls, invalid_tool_calls = [], [] + for function_call in function_calls: + try: + args = json.loads(function_call.arguments) + if not isinstance(args, dict): + raise json.JSONDecodeError(f"Invalid json tool arguments: {args}") + tool_call = ToolCall( + args=args, + name=function_call.name, + id=str(uuid.uuid4()), + ) + tool_calls.append(tool_call) + except json.JSONDecodeError as e: + logger.warning(f"Invalid json tool arguments: {e}") + tool_call = InvalidToolCall( + args=function_call.arguments, + name=function_call.name, + error=f"Invalid json tool arguments: {e}", + ) + invalid_tool_calls.append(tool_call) + + message = AIMessage( + content=content, + tool_calls=tool_calls[: self.max_parallel_calls], + invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls], + response_metadata={ + "request_id": request_id, + "prompt_ids": prompt_ids, + "response_mask": response_mask, + }, + ) + return message + + +class TruncateStructuredTool(StructuredTool): + """Structured tool with response truncation.""" + + tool_response_truncate_side: str + """truncate side of tool response: left, middle, right""" + + max_tool_response_length: int + """max length of tool response""" + + async def _arun( + self, + *args: Any, + config: RunnableConfig, + **kwargs: Any, + ) -> Any: + tool_response = await super()._arun(*args, config=config, **kwargs) + tool_response = str(tool_response) + + if len(tool_response) > self.max_tool_response_length: + if self.tool_response_truncate_side == "left": + tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" + elif self.tool_response_truncate_side == "right": + tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] + else: + length = self.max_tool_response_length // 2 + tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] + + return tool_response + + +def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput: + """Convert messages to AgentLoopOutput. + + Args: + messages (List[BaseMessage]): List of messages, last message must be assistant + with response_metadata containing `prompt_ids` and `response_mask`. + response_length (int): Max length of response. + + Returns: + AgentLoopOutput: agent loop output trajectory used for training. + """ + # skip last tool calls + for i in range(len(messages) - 1, -1, -1): + if messages[i].type != "tool": + break + last_message = messages[i] + assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}" + assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata" + assert "response_mask" in last_message.response_metadata, ( + "Last message must have response_mask in response_metadata" + ) + + num_turns = 0 + for i in range(len(messages)): + if messages[i].type == "system": + continue + # parallel tool calls are in single turn + if i == 0 or messages[i].type != messages[i - 1].type: + num_turns += 1 + + prompt_ids = last_message.response_metadata["prompt_ids"] + response_mask = last_message.response_metadata["response_mask"] + + response_ids = prompt_ids[-len(response_mask) :] + prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[:response_length], + response_mask=response_mask[:response_length], + num_turns=num_turns, + metrics={}, + ) + return output diff --git a/recipe/langgraph_agent/example/README.md b/recipe/langgraph_agent/example/README.md new file mode 100644 index 000000000..021e875bc --- /dev/null +++ b/recipe/langgraph_agent/example/README.md @@ -0,0 +1,111 @@ +# MathExpression: LangGraph Agent Example + +MathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/). + +### Define react agent with tool +Firstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`: +```python +@tool(parse_docstring=True) +def calculate(a: int, b: int, operand: str) -> int: + """ + Compute the results using operand with two integers + + Args: + a: the first operand + b: the second operand + operand: '+' or '-' or '*' or '@' + """ + assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}" + if operand == "@": + return 3 * a - 2 * b + return eval(f"{a} {operand} {b}") +``` + +Without calling `calculate`, ReactAgent is impossible to evaluate math expression correctly. + +Then, we can equip ReactAgent with `calculate` tool: +```python +class MathExpressionReactAgentLoop(ReactAgentLoop): + @classmethod + def init_class(cls, config, tokenizer): + cls.tools = [calculate] + super().init_class(config, tokenizer) +``` + +We can define agent loop config in yaml file, which will be used by AgentLoopWorker to dynamic load custom AgentLoop class. +```yaml +- name: math_expression + _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop +``` + +### Prepare dataset +Now, let's prepare two small datasets for training and evaluation: +```bash +python recipe/langgraph_agent/example/create_dataset.py +``` + +Note that dataset should contain a column `agent_name` with `math_expression`, which is used by `AgentLoopWorker` to select the +agent loop class. +| prompt | reward_model | agent_name | +|--------------------------------------|------------------------------|-----------------| +| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression | +| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression | + +Generated math expressions are like below, requiring model to call `calculate` multiple times to solve sub expressions. +``` +(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6 +(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6 +9 @ 4 +((2 @ 2) @ (3 @ 3)) @ 4 +``` + +### Training +Hook all these up and start training: +```bash +bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log +``` + +After total 39 steps, model should achieve 100% accuray on test dataset: +- val-aux/lighteval/MATH/reward: 1.0 +- val-aux/num_turns/mean: 9.0, average number of messages include assistant and tool turns. + +We can see model can solve math expression by calling `calculate` tool: +``` +system +You are Qwen, created by Alibaba Cloud. You are a helpful assistant. +user +We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions: +Compute 5 @ (3 @ 6 @ 8) +assistant +Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\). + +Let's compute \(3 @ 6\) first using the external tool. + +{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}} +user + +-3 + +assistant +We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\). + +Let's use the external tool to compute this. + +{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}} +user + +-25 + +assistant +We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\). + +Let's compute \(5 @ (-25)\) using the external tool. + +{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}} +user + +65 + +assistant +The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\). +``` diff --git a/recipe/langgraph_agent/example/agent.yaml b/recipe/langgraph_agent/example/agent.yaml new file mode 100644 index 000000000..cbd8fb9eb --- /dev/null +++ b/recipe/langgraph_agent/example/agent.yaml @@ -0,0 +1,2 @@ +- name: math_expression + _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop diff --git a/recipe/langgraph_agent/example/create_dataset.py b/recipe/langgraph_agent/example/create_dataset.py new file mode 100644 index 000000000..fb14e755d --- /dev/null +++ b/recipe/langgraph_agent/example/create_dataset.py @@ -0,0 +1,277 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Create dataset for calculator +""" + +import random + +import pandas as pd + + +def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2): + """ + Generate a random mathematical expression with operators +, -, *, /, and parentheses. + + Args: + min_terms (int): Minimum number of terms in the expression. + max_terms (int): Maximum number of terms in the expression. + max_number (int): Maximum value for numbers in the expression. + allow_decimals (bool): Whether to allow decimal numbers. + max_depth (int): Maximum nesting depth for parentheses. + + Returns: + str: A valid mathematical expression as a string. + """ + + def generate_number(): + """Generate a random number (integer or float).""" + assert min_number < max_number + num = random.uniform(min_number, max_number) + if not allow_decimals: + num = int(num) + else: + num = round(num, random.randint(0, 2)) # Round to 0-2 decimal places + return str(num) + + def generate_term(depth=0): + """Generate a term (number or parenthesized expression).""" + if depth < max_depth and random.random() < 0.5: # 50% chance to add parentheses + expr = generate_expression(depth + 1) + return f"({expr})" + else: + return generate_number() + + def generate_expression(depth=0): + """Generate a full expression with multiple terms and operators.""" + num_terms = random.randint(min_terms, max_terms) + terms = [generate_term(depth) for _ in range(num_terms)] + + # Randomly select operators + operators = ["+", "-", "*", "/", "@"] + expr = terms[0] + + for i in range(1, num_terms): + # Bias towards + and - for readability + op = random.choices( + operators, + weights=[0, 0, 0, 0, 1], # + and - are 1.5x more likely than * and / + )[0] + expr += f" {op} " + terms[i] + + return expr + + return generate_expression() + + +def test(): + # Example 1: Basic integer expression + print(generate_math_expression()) + # Output: (3 + 7) * 2 - 5 + + # Example 2: Expression with decimals + print(generate_math_expression(allow_decimals=True)) + # Output: 4.5 / (2.1 + 3.7) - 1.2 + + # Example 3: More complex expression with higher depth + print(generate_math_expression(max_terms=6, max_depth=3)) + # Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4 + + # Example 4: Simplified expression + print(generate_math_expression(min_terms=2, max_terms=3, max_number=5)) + # Output: 4 - 2 * 3 + + +def calculate(expression: str) -> float: + """ + Evaluate a mathematical expression with +, -, *, /, @, and parentheses. + The @ operator is defined as: a @ b = 3a - 2b. + + Args: + expression (str): Input mathematical expression (e.g., "3@2+4"). + + Returns: + float: Result of the evaluated expression. + + Raises: + ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero). + """ + + def tokenize(s: str) -> list: + """Convert the input string into tokens (numbers, operators, parentheses).""" + tokens = [] + i = 0 + while i < len(s): + if s[i].isdigit() or s[i] == ".": + # Parse number (integer or float) + j = i + while j < len(s) and (s[j].isdigit() or s[j] == "."): + j += 1 + tokens.append(s[i:j]) + i = j + elif s[i] in "+-*/@()": + # Operator or parenthesis + tokens.append(s[i]) + i += 1 + elif s[i].isspace(): + # Skip whitespace + i += 1 + else: + raise ValueError(f"Invalid character: {s[i]}") + return tokens + + def infix_to_postfix(tokens: list) -> list: + """Convert infix notation to postfix notation (Reverse Polish Notation).""" + output = [] + stack = [] + # Higher precedence for @ (between * and +) + precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1} + + for token in tokens: + if token.isdigit() or "." in token: + output.append(token) + elif token == "(": + stack.append(token) + elif token == ")": + while stack and stack[-1] != "(": + output.append(stack.pop()) + if not stack or stack[-1] != "(": + raise ValueError("Mismatched parentheses") + stack.pop() # Discard '(' + else: # Operator + while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0): + output.append(stack.pop()) + stack.append(token) + + # Pop remaining operators + while stack: + if stack[-1] in "()": + raise ValueError("Mismatched parentheses") + output.append(stack.pop()) + + return output + + def evaluate_postfix(postfix: list) -> float: + """Evaluate postfix expression using a stack.""" + stack = [] + for token in postfix: + if token.isdigit() or "." in token: + stack.append(float(token)) + else: + if len(stack) < 2: + raise ValueError("Invalid expression") + b = stack.pop() + a = stack.pop() + if token == "+": + res = a + b + elif token == "-": + res = a - b + elif token == "*": + res = a * b + elif token == "/": + if b == 0: + raise ValueError("Division by zero") + res = a / b + elif token == "@": + res = 3 * a - 2 * b # Custom @ operator implementation + else: + raise ValueError(f"Invalid operator: {token}") + stack.append(res) + + if len(stack) != 1: + raise ValueError("Invalid expression") + return stack[0] + + # Remove spaces and validate parentheses + expression = expression.replace(" ", "") + if expression.count("(") != expression.count(")"): + raise ValueError("Mismatched parentheses") + + tokens = tokenize(expression) + postfix = infix_to_postfix(tokens) + result = evaluate_postfix(postfix) + + # Convert integers to integer representation + if result.is_integer(): + return int(result) + return result + + +def generate_data(total_num_dataset, split): + rl_dataset = { + "prompt": [], + "data_source": [], + "ability": [], + "reward_model": [], + "extra_info": [], + "agent_name": [], + } + + for idx in range(total_num_dataset): + while True: + try: + expression: str = generate_math_expression( + min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1 + ) + + num_plus = expression.count("+") + num_minus = expression.count("-") + num_mul = expression.count("*") + num_star = expression.count("@") + + answer = str(calculate(expression)) + # answer = str(eval(expression)) + break + except Exception as e: + print(e) + continue + + num_tool_calls = num_plus + num_minus + num_mul + num_star + + prompt = ( + f"We define a new math operator @, where you can only call an external tool to compute. " + f"Please put your final answer inside \\boxed{{}} only in the last turn. Now answer the " + f"following questions:\nCompute {expression}" + ) + prompt_with_template = [ + { + "role": "user", + "content": prompt, + } + ] + + rl_dataset["prompt"].append(prompt_with_template) + rl_dataset["data_source"].append("lighteval/MATH") + rl_dataset["ability"].append("math") + rl_dataset["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer}) + rl_dataset["extra_info"].append( + {"index": idx, "expression": expression, "split": split, "expected_tool_calls": num_tool_calls} + ) + rl_dataset["agent_name"].append("math_expression") + + rl_dataset = pd.DataFrame(data=rl_dataset) + return rl_dataset + + +if __name__ == "__main__": + # print(calculate("3@2")) # Output: 5 (3*3 - 2*2) + # print(calculate("3@2+4")) # Output: 9 (5 + 4) + # print(calculate("3*(4@2)")) # Output: 24 (3 * 8) + # print(calculate("(5@3)*2")) # Output: 18 (9 * 2) + + train_dataset = generate_data(total_num_dataset=5000, split="train") + test_dataset = generate_data(total_num_dataset=500, split="test") + + train_dataset.to_parquet("train.parquet") + test_dataset.to_parquet("test.parquet") diff --git a/recipe/langgraph_agent/example/math_expression.py b/recipe/langgraph_agent/example/math_expression.py new file mode 100644 index 000000000..4532c8af3 --- /dev/null +++ b/recipe/langgraph_agent/example/math_expression.py @@ -0,0 +1,39 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from langchain_core.tools import tool + +from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop + + +@tool(parse_docstring=True) +def calculate(a: int, b: int, operand: str) -> int: + """ + Compute the results using operand with two integers + + Args: + a: the first operand + b: the second operand + operand: '+' or '-' or '*' or '@' + """ + assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}" + if operand == "@": + return 3 * a - 2 * b + return eval(f"{a} {operand} {b}") + + +class MathExpressionReactAgentLoop(ReactAgentLoop): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + cls.tools = [calculate] + super().init_class(config, tokenizer) diff --git a/recipe/langgraph_agent/example/run_qwen2.5_3b.sh b/recipe/langgraph_agent/example/run_qwen2.5_3b.sh new file mode 100644 index 000000000..4a398bb6a --- /dev/null +++ b/recipe/langgraph_agent/example/run_qwen2.5_3b.sh @@ -0,0 +1,99 @@ +set -x + +# ================= data/model/tool ================= +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct + +train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet +test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet + +# agent +agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml + +# wandb +project_name=math_expression_tool +experiment_name=qwen2.5-3b +default_local_dir=$DATA_ROOT/checkpoint/$experiment_name + +# ================= algorithm ================= +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=8 +max_prompt_length=1024 +max_response_length=2048 +actor_lr=1e-6 + +train_batch_size=128 +ppo_mini_batch_size=16 +n_resp_per_prompt=8 +n_resp_per_prompt_val=1 + +# ================= perfomance ================= +infer_tp=2 # vllm +train_sp=4 # train +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 )) +log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \ + trainer.val_before_train=True \ + trainer.log_val_generations=50 \ + trainer.nnodes=$ARNOLD_WORKER_NUM \ + trainer.save_freq=-1 \ + trainer.default_local_dir=$default_local_dir \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ diff --git a/recipe/langgraph_agent/react_agent_loop.py b/recipe/langgraph_agent/react_agent_loop.py new file mode 100644 index 000000000..578968a92 --- /dev/null +++ b/recipe/langgraph_agent/react_agent_loop.py @@ -0,0 +1,133 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +LangGraph React Agent Loop. + +This implementation is exact same as `ToolAgentLoop`. + +Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/ +""" + +from typing import Any, Literal + +from langchain_core.runnables import RunnableConfig +from langgraph.graph import END, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode + +from recipe.langgraph_agent.chat_model import ( + ChatModel, + MaxTokenExceededError, + convert_to_agent_output, +) +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput + + +async def call_model(state: MessagesState, config: RunnableConfig): + model = config["configurable"]["model"] + sampling_params = config["configurable"]["sampling_params"] + try: + message = await model.ainvoke(state["messages"], sampling_params=sampling_params) + return {"messages": [message]} + except MaxTokenExceededError: + # last message is ToolMessage + return {"messages": []} + + +def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]: + max_assistant_turns = config["configurable"]["max_assistant_turns"] + num_assistant_turns = 0 + for message in state["messages"]: + if message.type == "ai": + num_assistant_turns += 1 + + last_message = state["messages"][-1] + + # LLM call failed, e.g: max response length exceeded + if last_message.type == "tool": + return END + + # max assistant turns exceeded + if max_assistant_turns and num_assistant_turns >= max_assistant_turns: + return END + + # no tool calls + if not last_message.tool_calls: + return END + + return "tools" + + +class ReactAgentLoop(AgentLoopBase): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + if cls._class_initialized: + return + cls._class_initialized = True + print("Performing class-level ReactAgentLoop initialization") + + # build graph + cls.graph = cls.build_graph() + + @classmethod + def build_graph(cls) -> StateGraph: + workflow = StateGraph(MessagesState) + + workflow.add_node("agent", call_model) + workflow.add_node("tools", ToolNode(cls.tools)) + workflow.set_entry_point("agent") + workflow.add_conditional_edges( + "agent", + should_continue, + { + "tools": "tools", + END: END, + }, + ) + + workflow.add_edge("tools", "agent") + graph = workflow.compile() + return graph + + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + model_path = self.config.actor_rollout_ref.model.path + model_name = "/".join(model_path.split("/")[-2:]) + + rollout = self.config.actor_rollout_ref.rollout + model = ChatModel( + model=model_name, + client=self.server_manager, + tokenizer=self.tokenizer, + max_tokens=rollout.response_length, + max_parallel_calls=rollout.multi_turn.max_parallel_calls, + tool_parser=rollout.multi_turn.format, + ) + + model = model.bind_tools(self.tools, tool_choice="any") + + config = { + "configurable": { + "model": model, + "sampling_params": sampling_params, + "max_user_turns": rollout.multi_turn.max_user_turns, + "max_assistant_turns": rollout.multi_turn.max_assistant_turns, + } + } + + # TODO: how to handle multiple trajectories in an graph invocation? + # Each graph node may has its own LLM calls and state, e.g: + # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart + state = await self.graph.ainvoke(input={"messages": messages}, config=config) + + output = convert_to_agent_output(state["messages"], rollout.response_length) + return output diff --git a/recipe/langgraph_agent/test_react_agent_loop.py b/recipe/langgraph_agent/test_react_agent_loop.py new file mode 100644 index 000000000..0cdc91959 --- /dev/null +++ b/recipe/langgraph_agent/test_react_agent_loop.py @@ -0,0 +1,199 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +import numpy as np +import pytest +import ray +from langchain_core.tools import tool +from omegaconf import DictConfig + +from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop +from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +@tool(parse_docstring=True) +def get_current_temperature(location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + print(f"[DEBUG] get_current_temperature: {location}, {unit}") + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + +@tool(parse_docstring=True) +def get_temperature_date(location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + +class TestReactAgentLoop(ReactAgentLoop): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + # TODO: find better way to configure tools + cls.tools = [get_current_temperature, get_temperature_date] + super().init_class(config, tokenizer, **kwargs) + + +def test_react_agent(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + agent_loop_config = [ + { + "_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop", + "name": "react_agent", + }, + ] + agent_loop_config_path = "/tmp/agent_loop_config.json" + with open(agent_loop_config_path, "w") as f: + json.dump(agent_loop_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 + init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path + agent_loop_manager = init_agent_loop_manager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "user", "content": "What's the temperature in New York now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["react_agent"] * len(raw_prompts)), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + if i // n == 0: + # [user, assistant] + assert num_turns[i] == 2 + else: + # [user, assistant, tool, assistant] + assert num_turns[i] == 4 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + response_length = response_mask.size(1) + + for i in range(len(responses)): + # response with tool response + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + print("=========================") + print(response_with_obs) + print("---") + print(response_without_obs) + + print("Test passed!") + ray.shutdown() diff --git a/recipe/minicpmo/rl_dataset.py b/recipe/minicpmo/rl_dataset.py new file mode 100644 index 000000000..5ce15fb12 --- /dev/null +++ b/recipe/minicpmo/rl_dataset.py @@ -0,0 +1,553 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import os +import re +from typing import Optional + +import datasets +import torch +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from transformers import PreTrainedTokenizer, ProcessorMixin + +import verl.utils.torch_functional as verl_F +from verl.utils.dataset.vision_utils import process_image +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +def build_transform(): + IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN + IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + ] + ) + + +def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None): + if new_schema: + start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) + else: + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + if len(image_start_tokens) != len(image_end_tokens): + logger.error("image start token != image end tokens") + raise Exception("image start token != image end tokens") + if len(image_start_tokens) > 0: + image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]) + else: + image_bound = [] + return image_bound + + +def preprocess( + images_dict, + conversations, + tokenizer, + transform, + query_nums=64, + slice_config=None, + llm_type=None, + patch_size=14, + batch_vision=False, + max_length=2048, + truncation="error", + logger=None, +): + """ + single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation + """ + conversations = copy.deepcopy(conversations) + assert conversations[0]["role"] == "user", "the first role must be user" + + if slice_config is not None: + assert isinstance(slice_config, dict) + assert "patch_size" in slice_config + assert "max_slice_nums" in slice_config + assert "scale_resolution" in slice_config + default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + new_schema = False + use_image_id = False + if llm_type == "qwen": + new_schema = True + use_image_id = True + image_placeholder_dict = {} + images = [] + image_id_cnt = 0 + for img_name, image in images_dict.items(): + if slice_config: + source_image, patches, best_grid = slice_image( + image, + slice_config["max_slice_nums"], + slice_config["scale_resolution"], + slice_config["patch_size"], + ) + images.append(source_image) + image_placeholder = default_image_placeholder + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + images.append(patches[i][j]) + if use_image_id: + image_placeholder = ( + f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + ) + image_id_cnt += 1 + image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema) + image_placeholder_dict[img_name] = image_placeholder + else: + images.append(image) + if use_image_id: + image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + image_id_cnt += 1 + else: + image_placeholder = default_image_placeholder + image_placeholder_dict[img_name] = image_placeholder + + images = [transform(i) for i in images] + + if len(images_dict) == 1 and "" in images_dict: + if "" in conversations[0]["content"]: + conversations[0]["content"] = conversations[0]["content"].replace("", image_placeholder) + else: + conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"] + else: + pattern = r"" + new_conversations = [] + for conversation in conversations: + content = conversation["content"] + parts = re.split(f"({pattern})", content) + for i, part in enumerate(parts): + if not part.strip(): + continue + if re.match(pattern, part): + if part in image_placeholder_dict: + parts[i] = image_placeholder_dict[part] + else: + raise Exception(f"not found {part} in image dict") + conversation["content"] = "\n".join(parts) + new_conversations.append(conversation) + conversations = new_conversations + + # TODO change role in conversation for different llm + prompt_with_chat_template = tokenizer.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False) + + input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( + prompt=prompt_with_chat_template, + tokenizer=tokenizer, + max_length=max_length, + pad_token_id=tokenizer.pad_token_id, + left_pad=True, + truncation=truncation, + ) + position_ids = compute_position_id_with_mask(attention_mask) + image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger) + + input_dict = { + "input_ids": input_ids[0], + "attention_mask": attention_mask[0], + "position_ids": position_ids[0], + "image_bound": image_bound, + } + + if batch_vision: + tgt_sizes = [] + reshape_images = [] + for image in images: + H, W = image.shape[1:] + reshape_image = reshape_by_patch(image, patch_size) + reshape_images.append(reshape_image) + tgt_sizes.append([H // patch_size, W // patch_size]) + if tgt_sizes: + tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32) + + input_dict["pixel_values"] = reshape_images + input_dict["tgt_sizes"] = tgt_sizes + + else: + input_dict["pixel_values"] = images + input_dict["tgt_sizes"] = [] + + return input_dict + + +def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False): + original_size = image.size + original_width, original_height = original_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / (scale_resolution * scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + + source_image = None + best_grid = None + patches = [] + + if multiple <= 1 or never_split: + # dont need to slice, upsample + best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True) + source_image = image.resize(best_size, Image.Resampling.BICUBIC) + else: + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + # source image, down-sampling and ensure divided by patch_size + best_resize = find_best_resize(original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + candidate_grids = [] + + # find best grid + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True) + + refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) + patches = split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + +def ensure_divide(length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + +def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + if (width * height > scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = ensure_divide(width, patch_size) + best_height = ensure_divide(height, patch_size) + return (best_width, best_height) + + +def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + grid_x, grid_y = grid + + refine_width = ensure_divide(width, grid_x) + refine_height = ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = find_best_resize( + (grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale, + ) + + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + + return refine_size + + +def split_to_patches(image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + + return patches + + +def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): + if new_schema: + image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end + else: + image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(image_placeholder) + slices.append("".join(lines)) + if new_schema: + slice_placeholder = "\n".join(slices) + else: + slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + return slice_placeholder + + +def reshape_by_patch(image_tensor, patch_size): + """ + :param image_tensor: shape [3, H, W] + :param patch_size: + :return: [3, patch_size, HW/patch_size] + """ + patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)) + + patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) + patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1) + return patches + + +def init_minicpmo_config(processor, config): + """Initialize MiniCPM-o specific configuration""" + minicpmo_config = { + "transform": build_transform(), + "patch_size": config.get("patch_size", 14), + "query_nums": config.get("query_nums", 64), + "slice_config": config.get( + "slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448} + ), + "llm_type": config.get("llm_type", "qwen"), + "batch_vision": config.get("batch_vision", True), + } + return minicpmo_config + + +def process_minicpmo_data( + row_dict, messages, tokenizer, minicpmo_config, image_key, max_prompt_length, truncation, logger +): + """Process data for MiniCPM-o model""" + if len(row_dict[image_key]) == 1: + multi_modal_data = {} + image = process_image(row_dict.pop(image_key)[0]) + multi_modal_data["image"] = [image] + images_dict = {"": image} + else: + raise NotImplementedError + + model_inputs = preprocess( + images_dict, + messages, + tokenizer, + minicpmo_config["transform"], + query_nums=minicpmo_config["query_nums"], + slice_config=minicpmo_config["slice_config"], + llm_type=minicpmo_config["llm_type"], + patch_size=minicpmo_config["patch_size"], + batch_vision=minicpmo_config["batch_vision"], + max_length=max_prompt_length, + truncation=truncation, + logger=logger, + ) + + raw_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + raw_prompt = raw_prompt.replace("", "(./)") + + return model_inputs, multi_modal_data, raw_prompt + + +class RLHFDataset(Dataset): + """ + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + if not isinstance(data_files, list | ListConfig): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.return_full_prompt = config.get("return_full_prompt", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) + self.use_shm = config.get("use_shm", False) + self.chat_template_func = config.get("chat_template_func", None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False + self.minicpmo_config = init_minicpmo_config(self.processor, config) + self._download() + self._read_files_and_tokenize() + + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + + def __len__(self): + return len(self.dataframe) + + def _build_messages(self, example: dict): + return example.pop(self.prompt_key) + + def __getitem__(self, item): + """ + Note that we also return the raw_input_ids so that it can be combined with other chat template + """ + row_dict: dict = self.dataframe[item] + messages = self._build_messages(row_dict) + model_inputs = {} + + if self.processor is not None: + model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data( + row_dict, + messages, + self.tokenizer, + self.minicpmo_config, + self.image_key, + self.max_prompt_length, + self.truncation, + logger, + ) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + position_ids = model_inputs.pop("position_ids") + + # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature + row_dict["multi_modal_data"] = multi_modal_data + row_dict["multi_modal_inputs"] = dict(model_inputs) + else: + raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict["input_ids"] = input_ids + row_dict["attention_mask"] = attention_mask + row_dict["position_ids"] = position_ids + + raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) + if len(raw_prompt_ids) > self.max_prompt_length: + if self.truncation == "left": + raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] + elif self.truncation == "right": + raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] + elif self.truncation == "middle": + left_half = self.max_prompt_length // 2 + right_half = self.max_prompt_length - left_half + raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + elif self.truncation == "error": + raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + + row_dict["raw_prompt_ids"] = raw_prompt_ids + # encode prompts without chat template + if self.return_raw_chat: + row_dict["raw_prompt"] = messages + + # get prompts with chat template + if self.return_full_prompt: + row_dict["full_prompts"] = raw_prompt # array of strings + + # add index for each prompt + index = row_dict.get("extra_info", {}).get("index", 0) + tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {}) + interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {}) + need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs) + if need_tools_kwargs and not tools_kwargs: + logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"]) + row_dict["index"] = index + row_dict["tools_kwargs"] = tools_kwargs + row_dict["interaction_kwargs"] = interaction_kwargs + return row_dict + + def __getstate__(self): + if not self.serialize_dataset: + state = self.__dict__.copy() + + if "dataframe" in state: + del state["dataframe"] + return state + + return self.__dict__.copy() diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 56989bf93..23a3c4403 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -33,6 +33,8 @@ reward_model: ref_path: ${reward_model.model.path} use_remove_padding: True use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fused_kernel_options: + impl_backend: torch # triton, torch tokenizer_path: ${actor_rollout_ref.model.path} enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} ref_type: freeze diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py index 5f912e374..6bf7f5e45 100644 --- a/recipe/prime/main_prime.py +++ b/recipe/prime/main_prime.py @@ -72,8 +72,8 @@ def main_task(config, compute_score=None): tokenizer = hf_tokenizer(local_path) # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker @@ -140,6 +140,7 @@ def main_task(config, compute_score=None): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index c17c668a7..825671216 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -25,12 +25,19 @@ def masked_rloo(reward_tensor_original, mask_tensor): reward_tensor[~mask_tensor] = 0 for start_pos in range(0, reward_tensor.shape[0], n_samples): cur_rewards_mean = torch.cat( - [reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) for pos in range(start_pos, start_pos + n_samples)], + [ + reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) + for pos in range(start_pos, start_pos + n_samples) + ], dim=0, ) cur_rewards_sum = cur_rewards_mean.sum() cur_reward_baseline = cur_rewards_sum / (n_samples - 1) - reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] * (n_samples / (n_samples - 1)) - cur_reward_baseline + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = ( + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] + * (n_samples / (n_samples - 1)) + - cur_reward_baseline + ) return reward_tensor @@ -112,7 +119,9 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples): dpo_acc = [] for start_id in range(0, token_level_scores.shape[0], n_samples): - cur_scores = (token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]).sum(dim=1) + cur_scores = ( + token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples] + ).sum(dim=1) def get_upper_triangle(tensor_x): diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) @@ -125,7 +134,9 @@ def get_upper_triangle(tensor_x): if cur_acc_diff.abs().sum() == 0: cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 else: - cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum() + cur_acc = ( + ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs() + ).sum() / cur_acc_diff.abs().sum() dpo_acc.append(cur_acc.unsqueeze(0)) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 9999a3444..d15d772f0 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -25,11 +25,10 @@ import verl.utils.torch_functional as verl_F from verl import DataProto -from verl.trainer.ppo import core_algos -from verl.workers.critic import BasePPOCritic +from verl.utils.device import get_device_name from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm @@ -59,19 +58,27 @@ def _forward_micro_batch(self, micro_batch, prompt_length): max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # for compute the log_prob input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size + ) input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) output = self.reward_module( @@ -79,6 +86,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): attention_mask=None, position_ids=position_ids_rmpad, use_cache=False, + return_dict=self.use_fused_kernels, ) if self.use_fused_kernels: @@ -93,8 +101,12 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.ulysses_sequence_parallel_size > 1: - rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) - rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)[:, -num_actions - 1 : -1] + rm_log_labels = gather_outputs_and_unpad( + rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + rm_log_labels = pad_input( + hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] else: output = self.reward_module( @@ -102,6 +114,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length): attention_mask=micro_batch["attention_mask"], position_ids=micro_batch["position_ids"], use_cache=False, + return_dict=self.use_fused_kernels, ) if self.use_fused_kernels: @@ -110,12 +123,16 @@ def _forward_micro_batch(self, micro_batch, prompt_length): else: rm_output_logits = output.logits - rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :], dim=-1) # (batch_size, seq_length, vocab_size) - rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze(-1) # (batch, seq_length) + rm_log_prob = torch.nn.functional.log_softmax( + rm_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze( + -1 + ) # (batch, seq_length) if self.ref_module is not None: # do not have to pad again - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: ref_output = self.ref_module( input_ids=input_ids_rmpad, @@ -130,10 +147,16 @@ def _forward_micro_batch(self, micro_batch, prompt_length): else: ref_output_logits = ref_output.logits.squeeze(0) - ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, labels=input_ids_rmpad_rolled) + ref_log_labels = verl_F.logprobs_from_logits( + logits=ref_output_logits, labels=input_ids_rmpad_rolled + ) - ref_log_labels = gather_outpus_and_unpad(ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) - ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)[:, -num_actions - 1 : -1] + ref_log_labels = gather_outputs_and_unpad( + ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + ref_log_labels = pad_input( + hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] else: ref_output = self.ref_module( input_ids=micro_batch["input_ids"], @@ -148,8 +171,12 @@ def _forward_micro_batch(self, micro_batch, prompt_length): else: ref_output_logits = ref_output.logits - ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :], dim=-1) # (batch_size, seq_length, vocab_size) - ref_log_labels = ref_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze(-1) # (batch, seq_length) + ref_log_prob = torch.nn.functional.log_softmax( + ref_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + ref_log_labels = ref_log_prob.gather( + dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) + ).squeeze(-1) # (batch, seq_length) else: ref_log_labels = micro_batch["old_log_probs"] @@ -163,7 +190,8 @@ def _forward_micro_batch(self, micro_batch, prompt_length): # reward computation does not need gradient. only q needs with torch.no_grad(): - # generalized estimation of r should go before the reward filling. r means process reward for policy model, or the advantage of reward model. + # generalized estimation of r should go before the reward filling. r means process reward for policy + # model, or the advantage of reward model. lam = self.config.get("lambda", 0.0) beta = self.config.model.get("beta_train", 0.05) if lam == 0.0: @@ -174,7 +202,8 @@ def _forward_micro_batch(self, micro_batch, prompt_length): q_ = q * beta r = torch.zeros_like(q) lastgaelam = 0 - # change the last token and mask out all paddings to make this process easier if we rely on outcome reward to calculate V + # change the last token and mask out all paddings to make this process easier if we rely on + # outcome reward to calculate V for i in range(q.shape[0]): if self.config.prime_use_gt: q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum() @@ -204,7 +233,9 @@ def _optimizer_step(self): if isinstance(self.reward_module, FSDP): grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip + ) self.reward_optimizer.step() return grad_norm @@ -291,7 +322,7 @@ def update_rm(self, data: DataProto): self.reward_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() + data = data.to(get_device_name()) attention_mask = data["attention_mask"] acc = data["acc"] @@ -308,8 +339,11 @@ def update_rm(self, data: DataProto): if self.config.model.loss_type == "ce": dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta) elif self.config.model.loss_type == "dpo": - # the implementation of dpo is actually detached, which means we have to know the average value of w/l reward before the update. - dpo_loss = compute_detach_dpo_loss_rm(q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta) + # the implementation of dpo is actually detached, which means we have to know the average + # value of w/l reward before the update. + dpo_loss = compute_detach_dpo_loss_rm( + q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta + ) elif self.config.model.loss_type == "bon_acc": # change the original distribution of each sample to BoN distribution, then update reward model dpo_loss = compute_detach_dpo_loss_rm( diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 5b9cf4b8f..e35340464 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -25,7 +25,7 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_id, get_device_name, get_nccl_backend from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.fsdp_utils import ( @@ -38,6 +38,7 @@ offload_fsdp_optimizer, ) from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -53,7 +54,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend=get_nccl_backend()) self.config = config # build device mesh for Ulysses Sequence Parallel @@ -66,7 +67,9 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -129,11 +132,17 @@ def _build_reward_ref_model_optimizer(self, config): trust_remote_code=trust_remote_code, ) + fused_kernel_options = config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + apply_monkey_patch( model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_remove_padding=config.model.get("use_remove_padding", False), use_fused_kernels=config.model.get("use_fused_kernels", False), + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype @@ -186,7 +195,7 @@ def _build_reward_ref_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -202,7 +211,7 @@ def _build_reward_ref_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -228,7 +237,9 @@ def _build_reward_ref_model_optimizer(self, config): from verl.utils.torch_functional import get_constant_schedule_with_warmup - reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps) + reward_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps + ) return reward_module, ref_module, reward_optimizer, reward_lr_scheduler @@ -239,7 +250,9 @@ def init_model(self): from .prime_dp_rm import DataParallelPRIMERewardModel - self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer(config=self.config) + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = ( + self._build_reward_ref_model_optimizer(config=self.config) + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.reward_module) @@ -264,7 +277,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): - data = data.to("cuda") + data = data.to(get_device_name()) if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) @@ -299,12 +312,12 @@ def compute_rm_score(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_rm(self, data: DataProto): - data = data.to("cuda") + data = data.to(get_device_name()) if self._is_offload_param: load_fsdp_model_to_gpu(self.ref_module) load_fsdp_model_to_gpu(self.reward_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id()) # perform forward computation with self.ulysses_sharding_manager: @@ -320,7 +333,9 @@ def update_rm(self, data: DataProto): response_mask = data.batch["attention_mask"][:, prompt_length:] acc = data.batch["acc"] - dpo_acc_before = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]) + dpo_acc_before = compute_dpo_accuracy( + rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"] + ) dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item() @@ -344,7 +359,9 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) - self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 4b981c21f..a5ad96431 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -30,10 +30,11 @@ from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import _compute_response_info -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.utils.metric import reduce_metrics +from verl.utils.profiler.performance import simple_timer from . import prime_core_algos @@ -44,7 +45,9 @@ def compute_advantage(data: DataProto, adv_estimator, config): response_length = responses.size(-1) attention_mask = data.batch["attention_mask"] response_mask = attention_mask[:, -response_length:] - advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, config.actor_rollout_ref.rollout.n, config) + advantages, returns = prime_core_algos.compute_rloo_advantage_return( + data, response_mask, config.actor_rollout_ref.rollout.n, config + ) data.batch["advantages"] = advantages data.batch["returns"] = returns else: @@ -101,7 +104,9 @@ def compute_data_metrics(batch, use_critic=True): "response_length/mean": torch.mean(response_length).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), @@ -131,7 +136,10 @@ def compute_timing_metrics(batch, timing_raw): return { **{f"timing_s/{name}": value for name, value in timing_raw.items()}, - **{f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, } @@ -151,8 +159,9 @@ def __init__( ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, reward_fn=None, val_reward_fn=None, + device_name="cuda", ): - # assert torch.cuda.is_available(), 'cuda must be available on driver' + # assert get_torch_device().is_available(), 'cuda must be available on driver' super().__init__( config, @@ -160,8 +169,9 @@ def __init__( role_worker_mapping, resource_pool_manager, ray_worker_group_cls, - reward_fn, - val_reward_fn, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + device_name=device_name, ) self.use_critic = False @@ -174,7 +184,9 @@ def _create_dataloader(self, *args, **kwargs): from torch.utils.data import DataLoader, RandomSampler, SequentialSampler # TODO: we have to make sure the batch size is divisible by the dp size - self.train_dataset = RLHFDataset(data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data) + self.train_dataset = RLHFDataset( + data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data + ) # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() @@ -191,7 +203,9 @@ def _create_dataloader(self, *args, **kwargs): sampler=sampler, ) - self.val_dataset = RLHFDataset(data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data) + self.val_dataset = RLHFDataset( + data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data + ) self.val_dataloader = DataLoader( dataset=self.val_dataset, batch_size=len(self.val_dataset), @@ -222,11 +236,17 @@ def _create_dataloader(self, *args, **kwargs): def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) print(f"local_global_step_folder: {local_global_step_folder}") actor_local_path = os.path.join(local_global_step_folder, "actor") - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) self.actor_rollout_wg.save_checkpoint( actor_local_path, actor_remote_path, @@ -235,7 +255,11 @@ def _save_checkpoint(self): if self.use_rm: reward_local_path = os.path.join(local_global_step_folder, "reward") - reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") + reward_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") + ) self.rm_wg.save_checkpoint( reward_local_path, reward_remote_path, @@ -249,7 +273,9 @@ def _save_checkpoint(self): torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt") + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.global_steps)) @@ -275,7 +301,9 @@ def _load_checkpoint(self): else: if self.config.trainer.resume_mode == "resume_path": assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() @@ -290,7 +318,9 @@ def _load_checkpoint(self): actor_path = os.path.join(global_step_folder, "actor") reward_path = os.path.join(global_step_folder, "reward") # load actor - self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load rm if self.use_rm: self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) @@ -346,14 +376,17 @@ def fit(self): # pop those keys for generation gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - with _timer("step", timing_raw): + with simple_timer("step", timing_raw): # generate a batch - with _timer("gen", timing_raw): + with simple_timer("gen", timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == "remax": - with _timer("gen_max", timing_raw): + with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) @@ -368,7 +401,9 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -385,7 +420,7 @@ def fit(self): batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() # verify - with _timer("verify", timing_raw): + with simple_timer("verify", timing_raw): scores = self.reward_fn.verify(batch) metrics["acc"] = statistics.mean(scores) @@ -397,24 +432,24 @@ def fit(self): n_samples = self.config.actor_rollout_ref.rollout.n # recompute old_log_probs - with _timer("old_log_prob", timing_raw): + with simple_timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = compute_response_mask(batch) loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer("ref", timing_raw): + with simple_timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - with _timer("adv", timing_raw): + with simple_timer("adv", timing_raw): if self.use_rm: update_style = self.config.reward_model.model.get("update", "none") if update_style == "none": # only run forward @@ -428,13 +463,24 @@ def fit(self): metrics.update(reward_output_metrics) reward_output = self.rm_wg.compute_rm_score(batch) - elif update_style == "reverse": # run forward to calculate statistics, then update reward model + elif ( + update_style == "reverse" + ): # run forward to calculate statistics, then update reward model reward_output = self.rm_wg.compute_rm_score(batch) # broadcast q and acc tensor to each result bc_td = DataProto.from_dict( tensors={ - "Q_bc": reward_output.batch["q"].sum(dim=-1).view(-1, n_samples).unsqueeze(1).expand(-1, n_samples, -1).reshape(-1, n_samples), - "acc_bc": batch.batch["acc"].view(-1, n_samples).unsqueeze(1).expand(-1, n_samples, -1).reshape(-1, n_samples), + "Q_bc": reward_output.batch["q"] + .sum(dim=-1) + .view(-1, n_samples) + .unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), + "acc_bc": batch.batch["acc"] + .view(-1, n_samples) + .unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), } ) batch = batch.union(bc_td) @@ -447,22 +493,28 @@ def fit(self): metrics.update(reward_output_metrics) # compute advantages, executed on the driver process - batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config) + batch = compute_advantage( + batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config + ) # update actor - with _timer("update_actor", timing_raw): + with simple_timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0: - with _timer("testing", timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and self.global_steps % self.config.trainer.test_freq == 0 + ): + with simple_timer("testing", timing_raw): val_metrics: dict = self._validate() metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: - with _timer("save_checkpoint", timing_raw): + with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics @@ -480,8 +532,11 @@ def fit(self): val_metrics = self._validate() pprint(f"Final validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.save_freq > 0 and (self.global_steps - 1) % self.config.trainer.save_freq != 0: - with _timer("save_checkpoint", timing_raw): + if ( + self.config.trainer.save_freq > 0 + and (self.global_steps - 1) % self.config.trainer.save_freq != 0 + ): + with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() return @@ -497,15 +552,24 @@ def filter_and_downsample(self, scores, batch: DataProto): if self.config.data.filter_accuracy: acc_tensor = torch.mean(reward_matrix, dim=-1) - filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) | (acc_tensor < self.config.data.accuracy_lower_bound)] = False + filter_mask[ + (acc_tensor > self.config.data.accuracy_upper_bound) + | (acc_tensor < self.config.data.accuracy_lower_bound) + ] = False if self.config.data.filter_truncate: - length_matrix = batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :].sum(dim=-1).reshape(-1, n_samples) + length_matrix = ( + batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :] + .sum(dim=-1) + .reshape(-1, n_samples) + ) length_tensor = torch.max(length_matrix, dim=-1)[0] filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False reorder_index = torch.argsort(filter_mask, descending=True) reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) - batch.reorder(reorder_index[: int(len(batch) // self.config.data.oversample_factor)]) # this operation is inplace + batch.reorder( + reorder_index[: int(len(batch) // self.config.data.oversample_factor)] + ) # this operation is inplace return batch diff --git a/recipe/prime/run_prime_qwen.sh b/recipe/prime/run_prime_qwen.sh index 1623a3826..145f31b7b 100644 --- a/recipe/prime/run_prime_qwen.sh +++ b/recipe/prime/run_prime_qwen.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS gsm8k_train_path=$HOME/data/gsm8k/train.parquet gsm8k_test_path=$HOME/data/gsm8k/test.parquet @@ -56,7 +54,7 @@ python3 -m recipe.prime.main_prime \ reward_model.model.input_tokenizer=null \ reward_model.mini_batch_size=64 \ trainer.val_before_train=False \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='prime_example' \ trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/prime/run_prime_qwen_code.sh b/recipe/prime/run_prime_qwen_code.sh index c0932a8ee..e179c0858 100644 --- a/recipe/prime/run_prime_qwen_code.sh +++ b/recipe/prime/run_prime_qwen_code.sh @@ -1,7 +1,5 @@ set -x -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS # download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data code_train_path=$HOME/data/code/train.parquet @@ -53,7 +51,7 @@ python3 -m recipe.prime.main_prime \ reward_model.model.input_tokenizer=null \ reward_model.mini_batch_size=64 \ trainer.val_before_train=False \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='prime_example' \ trainer.experiment_name='Eurus-2-7B-SFT-code' \ trainer.n_gpus_per_node=8 \ diff --git a/recipe/r1/data_process.py b/recipe/r1/data_process.py index ab85af1aa..fb41c8143 100644 --- a/recipe/r1/data_process.py +++ b/recipe/r1/data_process.py @@ -43,7 +43,9 @@ def process_aime2024(example): data_source = "Maxwell-Jia/AIME_2024" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, split="train") - map_fn = partial(example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test") + map_fn = partial( + example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset @@ -51,14 +53,20 @@ def process_aime2024(example): def build_gpqa_dimond_dataset(): import random - GPQA_QUERY_TEMPLATE = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" + GPQA_QUERY_TEMPLATE = ( + "Answer the following multiple choice question. The last line of your response should be of the following " + "format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before " + "answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" + ) def process_gpqa_diamond(example): choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]] random.shuffle(choices) gold_index = random.randint(0, 3) choices.insert(gold_index, example["Correct Answer"]) - query_prompt = GPQA_QUERY_TEMPLATE.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"]) + query_prompt = GPQA_QUERY_TEMPLATE.format( + A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"] + ) gold_choice = "ABCD"[gold_index] return query_prompt, gold_choice @@ -66,7 +74,9 @@ def process_gpqa_diamond(example): print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, "gpqa_diamond", split="train") - map_fn = partial(example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test") + map_fn = partial( + example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset @@ -79,11 +89,15 @@ def process_cnmo2024(example): print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test") - map_fn_en = partial(example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test") + map_fn_en = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test" + ) dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names) dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test") - map_fn_zh = partial(example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test") + map_fn_zh = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test" + ) dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names) dataset = concatenate_datasets([dataset_en, dataset_zh]) @@ -99,12 +113,20 @@ def build_livecodebench_dataset(): def process_livecodebench(example): # Construct Query Prompt # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140 - query_prompt = f"You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\n\nQuestion: {example['question_content']}\n\n" + query_prompt = ( + f"You will be given a question (problem specification) and will generate a correct Python program " + f"that matches the specification and passes all tests.\n\nQuestion: {example['question_content']}\n\n" + ) if example["starter_code"]: - query_prompt += f"You will use the following starter code to write the solution to the problem and enclose your code within delimiters.\n```python\n{example['starter_code']}\n```" + query_prompt += ( + f"You will use the following starter code to write the solution to the problem and enclose your " + f"code within delimiters.\n```python\n{example['starter_code']}\n```" + ) else: query_prompt += ( - "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." + "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test " + "on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python " + "program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." "```python\n# YOUR CODE HERE\n```" ) @@ -114,7 +136,9 @@ def process_livecodebench(example): private_test_cases = json.loads(example["private_test_cases"]) except Exception as e: print(f"Error loading private test cases: {e}") - private_test_cases = json.loads(pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8"))))) + private_test_cases = json.loads( + pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))) + ) full_test_cases = public_test_cases + private_test_cases metadata = json.loads(example["metadata"]) @@ -131,7 +155,9 @@ def process_livecodebench(example): dataset = load_dataset(data_source, split="test") # R1 Evaluation use LiveCodeBench 24.08-25.01 dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00") - map_fn = partial(example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test") + map_fn = partial( + example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test" + ) dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8) return dataset diff --git a/recipe/r1/main_eval.py b/recipe/r1/main_eval.py index 2120b5541..b9c03791b 100644 --- a/recipe/r1/main_eval.py +++ b/recipe/r1/main_eval.py @@ -55,7 +55,9 @@ def main(config): data_source_reward = defaultdict(list) # Create remote tasks - remote_tasks = [process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)] + remote_tasks = [ + process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + ] # Process results as they come in with tqdm(total=total) as pbar: diff --git a/recipe/retool/retool.py b/recipe/retool/retool.py new file mode 100644 index 000000000..b4d6028ff --- /dev/null +++ b/recipe/retool/retool.py @@ -0,0 +1,120 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re +from typing import Any + +import datasets + +from verl.tools.base_tool import OpenAIFunctionToolSchema +from verl.tools.sandbox_fusion_tools import SandboxFusionTool +from verl.utils.dataset import RLHFDataset +from verl.utils.reward_score import math_dapo +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__name__) + + +class CustomSandboxFusionTool(SandboxFusionTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + code = parameters["code"] + matches = self.code_pattern.findall(code) + if matches: + code = matches[0].strip() + + # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + timeout = parameters.get("timeout", self.default_timeout) + language = parameters.get("language", self.default_language) + if not isinstance(code, str): + code = str(code) + + result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + # sandbox has no score or metrics, use Nones + return result, None, None + + +answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" + + +class CustomRLHFDataset(RLHFDataset): + """Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.""" + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset(parquet_file)["train"] + data_source = "/".join(parquet_file.split("/")[-2:]) + if data_source in ["Maxwell-Jia/AIME_2024", "yentinglin/aime_2025"]: + dataframe = dataframe.map( + self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names + ) + else: + dataframe = dataframe.map(self.map_fn2, num_proc=16) + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + def map_fn(self, row: dict, *, data_source: str = None): + if data_source == "Maxwell-Jia/AIME_2024": + problem, answer = row["Problem"], row["Answer"] + elif data_source == "yentinglin/aime_2025": + problem, answer = row["problem"], row["answer"] + + prompt = problem + answer_format + data = { + "data_source": data_source.split("/")[1].lower(), # aime_2024, aime_2025 + "prompt": [{"role": "user", "content": prompt}], + "ability": "MATH", + "reward_model": {"ground_truth": str(answer)}, + "agent_name": "tool_agent", + } + return data + + def map_fn2(self, row: dict): + content = row["prompt"][0]["content"] + row["prompt"][0]["content"] = content + answer_format + row["agent_name"] = "tool_agent" + return row + + +def compute_score(data_source, solution_str, ground_truth, extra_info): + # use \\boxed{...} answer + result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True) + + # encourage model to call tools + num_turns = extra_info["num_turns"] + if result["score"] < 0: + tool_call_reward = (num_turns - 2) / 2 * 0.1 + result["score"] = min(0, result["score"] + tool_call_reward) + + if result["pred"] is None: + result["pred"] = "" + + return result diff --git a/recipe/retool/retool_multi_turn_sft_preprocess.py b/recipe/retool/retool_multi_turn_sft_preprocess.py new file mode 100644 index 000000000..201ee6892 --- /dev/null +++ b/recipe/retool/retool_multi_turn_sft_preprocess.py @@ -0,0 +1,92 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the Retool dataset to parquet format +""" + +import argparse +import os + +import datasets + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/retool_multiturn") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--train_ratio", default=0.9, type=float) + parser.add_argument("--seed", default=42, type=int) + args = parser.parse_args() + + data_source = "swordfaith/ReTool-SFT-multi-turn" + dataset = datasets.load_dataset(data_source, "default") + + train_dataset = dataset["train"] + shuffled_train_dataset = train_dataset.shuffle(seed=args.seed) + split_idx = int(len(shuffled_train_dataset) * args.train_ratio) + train_dataset = shuffled_train_dataset.select(range(split_idx)) + test_dataset = shuffled_train_dataset.select(range(split_idx, len(shuffled_train_dataset))) + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + messages = example.pop("messages") + tools = example.pop("tools") + data = { + "data_source": data_source, + "messages": messages, + "tools": tools, + "enable_thinking": False, + "extra_info": { + "split": split, + "index": idx, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + # Create output directory + local_dir = os.path.expanduser(args.local_dir) + os.makedirs(local_dir, exist_ok=True) + + # Save to parquet files + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + # Handle HDFS if specified + if hdfs_dir is not None: + try: + from verl.utils.hdfs_io import copy, makedirs + + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) + except ImportError: + print("Warning: HDFS support not available. Skipping HDFS copy.") + + # Print statistics + print(f"Train dataset size: {len(train_dataset)}") + print(f"Test dataset size: {len(test_dataset)}") + print(f"Data saved to {local_dir}") + + +if __name__ == "__main__": + main() diff --git a/recipe/retool/retool_sft_preprocess.py b/recipe/retool/retool_sft_preprocess.py new file mode 100644 index 000000000..0a46c1522 --- /dev/null +++ b/recipe/retool/retool_sft_preprocess.py @@ -0,0 +1,133 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages. +""" + +import json +import re +from typing import Any + +import datasets +from omegaconf import OmegaConf + +code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) + + +def extract_code_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + code = content[i + len(start) : j] + matches = code_pattern.findall(code) + if matches: + code = matches[0].strip() + + message = { + "role": "assistant", + "content": content[:i].strip(), + "tool_calls": [ + { + "type": "function", + "function": { + "name": "code_interpreter", + "arguments": {"code": code}, + }, + }, + ], + } + return message, content[j + len(stop) :] + + +def extract_answer_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + answer = content[:i] + content[i + len(start) : j] + message = { + "role": "assistant", + "content": answer.strip(), + } + return message, content[j + len(stop) :] + + +def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + interpreter = content[i + len(start) : j] + message = { + "role": "tool", + "content": interpreter.strip(), + } + return message, content[j + len(stop) :] + + +def process(row: dict, *, tools: str): + messages = [] + + # extract problem + content = row["messages"][0]["content"] + start = "*user question:*" + i = content.find(start) + assert i != -1 + prompt = content[i + len(start) :].replace("", "").replace("", "").strip() + messages.append( + { + "role": "user", + "content": prompt, + } + ) + + # extract multi turns + content = row["messages"][1]["content"] + role = "assistant" + while len(content) > 0: + if role == "assistant": + message, content = extract_code_message(content) + if message is None: + message, content = extract_answer_message(content) + assert message is not None + messages.append(message) + role = "tool" + else: + message, content = extract_interpreter_message(content) + assert message is not None + messages.append(message) + role = "assistant" + + return {"messages": messages, "tools": tools} + + +if __name__ == "__main__": + tools_config_file = "recipe/retool/sandbox_fusion_tool_config.yaml" + tools_config = OmegaConf.load(tools_config_file) + tool_schema = OmegaConf.to_container(tools_config["tools"][0]["tool_schema"]) + tools = json.dumps([tool_schema]) + + data = datasets.load_dataset("JoeYing/ReTool-SFT")["train"] + data = data.map(process, fn_kwargs={"tools": tools}) + data.to_parquet("wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet") diff --git a/recipe/retool/run_qwen2-32b_sft.sh b/recipe/retool/run_qwen2-32b_sft.sh new file mode 100644 index 000000000..137698138 --- /dev/null +++ b/recipe/retool/run_qwen2-32b_sft.sh @@ -0,0 +1,59 @@ +#!/bin/bash +set -x + +# set dist args +nproc_per_node=${ARNOLD_WORKER_GPU} +if [ ! -z "$SINGLE" ] && [ "$SINGLE" != "0" ]; then + echo "[single node alone] SINGLE=$SINGLE" + MASTER_NODE_ID=${ARNOLD_ID} + nnodes=1 + node_rank=0 +else + MASTER_NODE_ID=0 + nnodes=${ARNOLD_WORKER_NUM} + node_rank=${ARNOLD_ID} +fi +master_addr="METIS_WORKER_${MASTER_NODE_ID}_HOST" +master_addr=${!master_addr} +master_port="METIS_WORKER_${MASTER_NODE_ID}_PORT" +master_port=${!master_port} +ports=(`echo $master_port | tr ',' ' '`) +master_port=${ports[0]} +echo "[nproc_per_node: ${nproc_per_node}]" +echo "[nnodes: ${nnodes}]" +echo "[node_rank: ${node_rank}]" +echo "[master_addr: ${master_addr}]" +echo "[master_port: ${master_port}]" + +experiment_name=multiturn-sft-qwen-2.5-32b-instruct +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +TRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet +EVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet +MODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct +SAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name + +torchrun --nnodes=$ARNOLD_WORKER_NUM \ + --nproc_per_node=$ARNOLD_WORKER_GPU \ + --master-addr=$master_addr \ + --master-port=$master_port \ + --node-rank=$node_rank \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$TRAIN_DATA \ + data.val_files=$EVAL_DATA \ + data.max_length=16384 \ + data.train_batch_size=32 \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=$MODEL_PATH \ + model.strategy=fsdp \ + trainer.default_local_dir=$SAVE_PATH \ + trainer.project_name=wuxibin-multiturn-sft \ + trainer.experiment_name=$experiment_name \ + trainer.logger='["console","wandb"]' \ + trainer.total_epochs=6 \ + ulysses_sequence_parallel_size=4 \ + use_remove_padding=true \ No newline at end of file diff --git a/recipe/retool/run_qwen2.5_32b_sp8.sh b/recipe/retool/run_qwen2.5_32b_sp8.sh new file mode 100644 index 000000000..4d6daa1dd --- /dev/null +++ b/recipe/retool/run_qwen2.5_32b_sp8.sh @@ -0,0 +1,33 @@ +#!/bin/bash +set -x + +export PYTHONUNBUFFERED=1 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 + +ulimit -n 65535 + +EXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-32b-sp8 + +torchrun --nnodes=1 --nproc_per_node=8 \ + -m verl.trainer.fsdp_sft_trainer \ + data.max_length=16384 \ + data.train_batch_size=128 \ + data.micro_batch_size_per_gpu=4 \ + data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ + data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-32B-Instruct \ + model.trust_remote_code=true \ + model.fsdp_config.cpu_offload=true \ + model.fsdp_config.offload_params=true \ + optim.lr=1e-6 \ + trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ + trainer.project_name=retool-multiturn-sft \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.logger='["console","wandb"]' \ + trainer.total_epochs=12 $@ \ + ulysses_sequence_parallel_size=8 \ + use_remove_padding=true diff --git a/recipe/retool/run_qwen2.5_7b_sp4.sh b/recipe/retool/run_qwen2.5_7b_sp4.sh new file mode 100644 index 000000000..9265dbbac --- /dev/null +++ b/recipe/retool/run_qwen2.5_7b_sp4.sh @@ -0,0 +1,33 @@ +#!/bin/bash +set -x + +export PYTHONUNBUFFERED=1 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 + +ulimit -n 65535 + +EXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-7b-sp4 + +torchrun --nnodes=1 --nproc_per_node=8 \ + -m verl.trainer.fsdp_sft_trainer \ + data.max_length=16384 \ + data.train_batch_size=128 \ + data.micro_batch_size_per_gpu=16 \ + data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ + data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-7B-Instruct \ + model.trust_remote_code=true \ + model.fsdp_config.cpu_offload=false \ + model.fsdp_config.offload_params=false \ + optim.lr=1e-6 \ + trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ + trainer.project_name=retool-multiturn-sft \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.logger='["console","wandb"]' \ + trainer.total_epochs=8 $@ \ + ulysses_sequence_parallel_size=4 \ + use_remove_padding=true diff --git a/recipe/retool/run_qwen3_4b_sp4.sh b/recipe/retool/run_qwen3_4b_sp4.sh new file mode 100644 index 000000000..23ec986e3 --- /dev/null +++ b/recipe/retool/run_qwen3_4b_sp4.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -x + +export PYTHONUNBUFFERED=1 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 + +ulimit -n 65535 + +EXPERIMENT_NAME=retool-multiturn-sft-qwen3-4b-sp4 + +torchrun --nnodes=1 --nproc_per_node=8 \ + -m verl.trainer.fsdp_sft_trainer \ + data.max_length=16384 \ + data.train_batch_size=128 \ + data.micro_batch_size_per_gpu=16 \ + data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ + data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + model.partial_pretrain=$HOME/models/Qwen/Qwen3-4B \ + model.trust_remote_code=true \ + optim.lr=1e-6 \ + trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ + trainer.project_name=retool-multiturn-sft \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.logger='["console","wandb"]' \ + trainer.total_epochs=12 $@ \ + ulysses_sequence_parallel_size=4 \ + use_remove_padding=true diff --git a/recipe/retool/sandbox_fusion_tool_config.yaml b/recipe/retool/sandbox_fusion_tool_config.yaml new file mode 100644 index 000000000..203457155 --- /dev/null +++ b/recipe/retool/sandbox_fusion_tool_config.yaml @@ -0,0 +1,24 @@ +tools: + - class_name: "recipe.retool.retool.CustomSandboxFusionTool" + config: + sandbox_fusion_url: "https://***.apigateway-cn-beijing.volceapi.com/run_code" + num_workers: 128 + enable_global_rate_limit: true + rate_limit: 128 + default_timeout: 30 + default_language: "python" + memory_limit_mb: 1024 + type: native + + tool_schema: + type: "function" + function: + name: "code_interpreter" + description: "A tool for executing code." + parameters: + type: "object" + properties: + code: + type: "string" + description: "The code to execute." + required: ["code"] diff --git a/recipe/spin/core_algos.py b/recipe/spin/core_algos.py index 83dae09b5..c48027e54 100644 --- a/recipe/spin/core_algos.py +++ b/recipe/spin/core_algos.py @@ -14,7 +14,6 @@ # limitations under the License. - import numpy as np import torch @@ -48,10 +47,10 @@ def update(self, current_kl, n_steps): def get_kl_controller(kl_ctrl): - if kl_ctrl.type == 'fixed': + if kl_ctrl.type == "fixed": return FixedKLController(kl_coef=kl_ctrl.kl_coef) - elif kl_ctrl.type == 'adaptive': - assert kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {kl_ctrl.horizon}' + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) else: raise NotImplementedError @@ -79,10 +78,12 @@ def compute_onlinedpo_pref( """ # print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----") if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0: - raise ValueError(f"Input tensor batch dimension must be even for pair comparison, " - f"got shapes: {token_level_rewards.shape}, {response_mask.shape}") + raise ValueError( + f"Input tensor batch dimension must be even for pair comparison, got shapes: " + f"{token_level_rewards.shape}, {response_mask.shape}" + ) if token_level_rewards.shape != response_mask.shape: - raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}") + raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}") # 1. Calculate Sequence Scores scores = (token_level_rewards * response_mask).sum(dim=-1) @@ -94,11 +95,11 @@ def compute_onlinedpo_pref( except RuntimeError as e: print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}") raise e - print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2] + print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2] # 3. Compare scores to find which index (0 or 1) is the winner within each pair # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1 - winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max + winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max) # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1] # print(f" Winner indices shape: {winner_indices.shape}") # [batch_size] @@ -108,7 +109,7 @@ def compute_onlinedpo_pref( num_pairs = score_pairs.shape[0] full_batch_size = num_pairs * 2 # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1] - full_indices = torch.arange(full_batch_size, device=scores.device) + # full_indices = torch.arange(full_batch_size, device=scores.device) # Create indices corresponding to the winner within each pair's original index # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2] # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4] @@ -125,7 +126,6 @@ def compute_onlinedpo_pref( # print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----") return output_preference_mask - def compute_online_dpo_loss( @@ -138,8 +138,8 @@ def compute_online_dpo_loss( loss_type: str = "sigmoid", reference_free: bool = False, ) -> torch.Tensor: - import torch.nn.functional as F + pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps @@ -157,12 +157,16 @@ def compute_online_dpo_loss( return losses.mean() -def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor: + +def get_batch_logps( + logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False +) -> torch.FloatTensor: """ Compute the log probabilities of the given labels under the given logits. Args: - logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`). Shape: (batch_size, sequence_length, vocab_size) + logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`). + Shape: (batch_size, sequence_length, vocab_size) labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum. @@ -177,24 +181,26 @@ def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - - # Calculate per token log probability - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none') + + # Calculate per token log probability + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - per_token_logps = per_token_logps.view(shift_logits.size(0), shift_logits.size(1)) # Reshape back to (batch_size, seq_len-1) - + per_token_logps = per_token_logps.view( + shift_logits.size(0), shift_logits.size(1) + ) # Reshape back to (batch_size, seq_len-1) + # Create a mask for the labels that are not -100 - loss_mask = (shift_labels != -100) - + loss_mask = shift_labels != -100 + # Apply the mask to the per token log probabilities masked_logps = per_token_logps * loss_mask - + # Calculate the sum or average log probability per sequence sequence_logps = masked_logps.sum(dim=-1) - + if average_log_prob: # Avoid division by zero for sequences with no valid tokens num_valid_tokens = loss_mask.sum(dim=-1) return sequence_logps / torch.clamp(num_valid_tokens, min=1) else: - return sequence_logps \ No newline at end of file + return sequence_logps diff --git a/recipe/spin/dp_actor.py b/recipe/spin/dp_actor.py index a0112d76d..35caa29c7 100644 --- a/recipe/spin/dp_actor.py +++ b/recipe/spin/dp_actor.py @@ -23,13 +23,14 @@ from recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps from verl import DataProto +from verl.utils.device import get_device_name from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.workers.actor import DataParallelPPOActor -__all__ = ['DataParallelPPOActor'] +__all__ = ["DataParallelPPOActor"] + class SPINDataParallelPPOActor(DataParallelPPOActor): - def compute_log_prob(self, data: DataProto) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids @@ -51,21 +52,21 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: # set to eval self.actor_module.eval() - micro_batch_size = data.meta_info['micro_batch_size'] - temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error - use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ['multi_modal_inputs'] + non_tensor_select_keys = ["multi_modal_inputs"] micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: micro_batches = batch.split(micro_batch_size) @@ -93,54 +94,54 @@ def update_policy_dpo_with_ref(self, data: DataProto): Performs the DPO update step using pre-calculated reference log probs from an external, periodically updated reference model. """ - self.actor_module.train() # Ensure training mode + self.actor_module.train() # Ensure training mode # --- Retrieve necessary data --- try: # Expects batch prepared by fit_dpo loop, including reference log probs batch_td = data.batch - chosen_labels = batch_td['chosen_labels'] - rejected_labels = batch_td['rejected_labels'] + chosen_labels = batch_td["chosen_labels"] + rejected_labels = batch_td["rejected_labels"] # ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ... # === Get PRE-CALCULATED reference log probs from input data === - reference_chosen_logps = batch_td['reference_chosen_logps'] # Should be sequence-level logps - reference_rejected_logps = batch_td['reference_rejected_logps'] # Should be sequence-level logps + reference_chosen_logps = batch_td["reference_chosen_logps"] # Should be sequence-level logps + reference_rejected_logps = batch_td["reference_rejected_logps"] # Should be sequence-level logps # ============================================================ # Get DPO params from meta_info # beta = data.meta_info.get('dpo_beta', 0.1) # Default beta - beta = self.config.get('dpo_beta', 0.1) # Default beta - loss_type = data.meta_info.get('dpo_loss_type', 'sigmoid') - label_smoothing = data.meta_info.get('dpo_label_smoothing', 0.0) + beta = self.config.get("dpo_beta", 0.1) # Default beta + loss_type = data.meta_info.get("dpo_loss_type", "sigmoid") + label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0) # reference_free should now be False as we provide ref logps - reference_free = data.meta_info.get('reference_free', False) # Default False + reference_free = data.meta_info.get("reference_free", False) # Default False except KeyError as e: print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}") - print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print - return {} # Return empty metrics on error + print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print + return {} # Return empty metrics on error except Exception as e_data: print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}") return {} # --- Micro-batching Setup --- - micro_batch_size = self.config.get('ppo_micro_batch_size_per_gpu') + micro_batch_size = self.config.get("ppo_micro_batch_size_per_gpu") if micro_batch_size is None: # Fallback or default if not set, or raise error - micro_batch_size = 1 # Example fallback, adjust as needed + micro_batch_size = 1 # Example fallback, adjust as needed print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}") # raise ValueError("Config 'ppo_micro_batch_size_per_gpu' must be set.") # Ensure chosen_input_ids exists before getting shape - if 'chosen_input_ids' not in batch_td: - print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.") - return {} - bsz = batch_td['chosen_input_ids'].shape[0] + if "chosen_input_ids" not in batch_td: + print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.") + return {} + bsz = batch_td["chosen_input_ids"].shape[0] if bsz == 0: print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.") - return {'actor/dpo_loss': 0.0, 'actor/grad_norm': 0.0} # Return zero metrics if batch is empty + return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0} # Return zero metrics if batch is empty num_micro_batches = math.ceil(bsz / micro_batch_size) gradient_accumulation_steps = num_micro_batches @@ -148,7 +149,7 @@ def update_policy_dpo_with_ref(self, data: DataProto): # --- Metrics Accumulation --- total_loss = 0.0 accumulated_metrics = defaultdict(list) - metrics = {} # Final metrics dict + metrics = {} # Final metrics dict # --- Zero Gradients --- self.actor_optimizer.zero_grad(set_to_none=True) @@ -157,31 +158,31 @@ def update_policy_dpo_with_ref(self, data: DataProto): for i in range(num_micro_batches): start_idx = i * micro_batch_size end_idx = min(start_idx + micro_batch_size, bsz) - if start_idx >= end_idx: continue + if start_idx >= end_idx: + continue # Slice the full DPO batch into micro-batches # Important: Slice ALL required tensors, including labels and inputs micro_batch_chosen_labels = chosen_labels[start_idx:end_idx] micro_batch_rejected_labels = rejected_labels[start_idx:end_idx] micro_batch_chosen_inputs = { - 'input_ids': batch_td['chosen_input_ids'][start_idx:end_idx], - 'attention_mask': batch_td['chosen_attention_mask'][start_idx:end_idx] + "input_ids": batch_td["chosen_input_ids"][start_idx:end_idx], + "attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx], } - if 'chosen_position_ids' in batch_td: - micro_batch_chosen_inputs['position_ids'] = batch_td['chosen_position_ids'][start_idx:end_idx] + if "chosen_position_ids" in batch_td: + micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx] micro_batch_rejected_inputs = { - 'input_ids': batch_td['rejected_input_ids'][start_idx:end_idx], - 'attention_mask': batch_td['rejected_attention_mask'][start_idx:end_idx] + "input_ids": batch_td["rejected_input_ids"][start_idx:end_idx], + "attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx], } - if 'rejected_position_ids' in batch_td: - micro_batch_rejected_inputs['position_ids'] = batch_td['rejected_position_ids'][start_idx:end_idx] - + if "rejected_position_ids" in batch_td: + micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx] # Determine autocast dtype - autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings + autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings # --- Autocast Forward Pass --- - with torch.autocast(device_type='cuda', dtype=autocast_dtype): + with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype): # --- Step 1: Forward pass for CURRENT policy log probs (with grad) --- policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False) policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False) @@ -202,32 +203,34 @@ def update_policy_dpo_with_ref(self, data: DataProto): # --- Step 4: Calculate DPO Logits and Loss --- pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values - logits = pi_logratios - ref_logratios # DPO logits + ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values + logits = pi_logratios - ref_logratios # DPO logits loss = compute_online_dpo_loss( - policy_chosen_logps=policy_chosen_logps, # Has grad - policy_rejected_logps=policy_rejected_logps, # Has grad - reference_chosen_logps=micro_ref_chosen_logps, # No grad (from input) - reference_rejected_logps=micro_ref_rejected_logps, # No grad (from input) + policy_chosen_logps=policy_chosen_logps, # Has grad + policy_rejected_logps=policy_rejected_logps, # Has grad + reference_chosen_logps=micro_ref_chosen_logps, # No grad (from input) + reference_rejected_logps=micro_ref_rejected_logps, # No grad (from input) beta=beta, label_smoothing=label_smoothing, loss_type=loss_type, - reference_free=reference_free # Should be False now + reference_free=reference_free, # Should be False now ) # --- Scale loss for gradient accumulation --- scaled_loss = loss / gradient_accumulation_steps # --- Accumulate Metrics --- - total_loss += loss.item() # Unscaled loss - accumulated_metrics['actor/dpo_loss_batch'].append(loss.item()) - accumulated_metrics['actor/dpo_logits_batch'].append(logits.mean().item()) + total_loss += loss.item() # Unscaled loss + accumulated_metrics["actor/dpo_loss_batch"].append(loss.item()) + accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item()) # Accumulate policy and reference log probs/ratios if needed for debugging - accumulated_metrics['actor/policy_chosen_logps_batch'].append(policy_chosen_logps.mean().item()) - accumulated_metrics['actor/policy_rejected_logps_batch'].append(policy_rejected_logps.mean().item()) - accumulated_metrics['actor/reference_chosen_logps_batch'].append(micro_ref_chosen_logps.mean().item()) - accumulated_metrics['actor/reference_rejected_logps_batch'].append(micro_ref_rejected_logps.mean().item()) + accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item()) + accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item()) + accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item()) + accumulated_metrics["actor/reference_rejected_logps_batch"].append( + micro_ref_rejected_logps.mean().item() + ) # --- Backward Pass (outside autocast) --- # Check if loss requires grad before backward @@ -236,41 +239,50 @@ def update_policy_dpo_with_ref(self, data: DataProto): else: print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.") - # --- End Micro-batch Loop --- # --- Optimizer Step (after accumulating gradients for all micro-batches) --- grad_norm = self._optimizer_step() # --- Populate Final Metrics --- - if num_micro_batches > 0 and bsz > 0: # Check if any processing happened - metrics['actor/dpo_loss'] = total_loss / num_micro_batches - metrics['actor/grad_norm'] = grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float('inf') + if num_micro_batches > 0 and bsz > 0: # Check if any processing happened + metrics["actor/dpo_loss"] = total_loss / num_micro_batches + metrics["actor/grad_norm"] = ( + grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf") + ) # Average other accumulated metrics for key, val_list in accumulated_metrics.items(): - if val_list: metrics[key.replace('_batch','')] = np.mean(val_list) + if val_list: + metrics[key.replace("_batch", "")] = np.mean(val_list) # Calculate accuracy / rewards / margins based on averaged logprobs if desired - if 'actor/policy_chosen_logps' in metrics and 'actor/policy_rejected_logps' in metrics and \ - 'actor/reference_chosen_logps' in metrics and 'actor/reference_rejected_logps' in metrics: - policy_ratio_mean = metrics['actor/policy_chosen_logps'] - metrics['actor/policy_rejected_logps'] - ref_ratio_mean = metrics['actor/reference_chosen_logps'] - metrics['actor/reference_rejected_logps'] + if ( + "actor/policy_chosen_logps" in metrics + and "actor/policy_rejected_logps" in metrics + and "actor/reference_chosen_logps" in metrics + and "actor/reference_rejected_logps" in metrics + ): + policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"] + ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"] logits_mean = policy_ratio_mean - ref_ratio_mean - metrics['actor/rewards_chosen'] = beta * (metrics['actor/policy_chosen_logps'] - metrics['actor/reference_chosen_logps']) - metrics['actor/rewards_rejected'] = beta * (metrics['actor/policy_rejected_logps'] - metrics['actor/reference_rejected_logps']) - metrics['actor/rewards_accuracies'] = float(logits_mean > 0) # Mean accuracy proxy - metrics['actor/rewards_margins'] = metrics['actor/rewards_chosen'] - metrics['actor/rewards_rejected'] - - else: # Handle case where no micro-batches were run (e.g., bsz=0) - metrics['actor/dpo_loss'] = 0.0 - metrics['actor/grad_norm'] = 0.0 - # Initialize other metrics to 0 or NaN as appropriate - for key in accumulated_metrics.keys(): - metrics[key.replace('_batch','')] = 0.0 - metrics['actor/rewards_chosen'] = 0.0 - metrics['actor/rewards_rejected'] = 0.0 - metrics['actor/rewards_accuracies'] = 0.0 - metrics['actor/rewards_margins'] = 0.0 - - - return metrics # Return aggregated metrics + metrics["actor/rewards_chosen"] = beta * ( + metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"] + ) + metrics["actor/rewards_rejected"] = beta * ( + metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"] + ) + metrics["actor/rewards_accuracies"] = float(logits_mean > 0) # Mean accuracy proxy + metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"] + + else: # Handle case where no micro-batches were run (e.g., bsz=0) + metrics["actor/dpo_loss"] = 0.0 + metrics["actor/grad_norm"] = 0.0 + # Initialize other metrics to 0 or NaN as appropriate + for key in accumulated_metrics.keys(): + metrics[key.replace("_batch", "")] = 0.0 + metrics["actor/rewards_chosen"] = 0.0 + metrics["actor/rewards_rejected"] = 0.0 + metrics["actor/rewards_accuracies"] = 0.0 + metrics["actor/rewards_margins"] = 0.0 + + return metrics # Return aggregated metrics diff --git a/recipe/spin/fsdp_workers.py b/recipe/spin/fsdp_workers.py index 3a26377a9..bbbfa0ed0 100644 --- a/recipe/spin/fsdp_workers.py +++ b/recipe/spin/fsdp_workers.py @@ -31,31 +31,41 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import ActorRolloutRefWorker from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager logger = logging.getLogger(__file__) -logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: - device_mesh = init_device_mesh('cuda', - mesh_shape=(world_size // fsdp_size, fsdp_size), - mesh_dim_names=['ddp', 'fsdp']) + device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) return device_mesh def get_sharding_strategy(device_mesh): from torch.distributed.fsdp import ShardingStrategy + if device_mesh.ndim == 1: sharding_strategy = ShardingStrategy.FULL_SHARD elif device_mesh.ndim == 2: @@ -64,18 +74,21 @@ def get_sharding_strategy(device_mesh): raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") return sharding_strategy + class SPINRolloutRefWorker(ActorRolloutRefWorker): - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor + # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) from omegaconf import OmegaConf - override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) - use_remove_padding = self.config.model.get('use_remove_padding', False) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) if self._is_actor or self._is_rollout or self._is_ref: # we need the model for actor and rollout @@ -85,57 +98,66 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False), - trust_remote_code=self.config.model.get('trust_remote_code', False), - use_liger=self.config.model.get('use_liger', False), - role='actor') + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) + ) # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage('After offload actor optimizer during init', logger=logger) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) # load from checkpoint if self._is_actor or self._is_ref: OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding - self.actor = DataParallelPPOActor(config=self.config.actor, - actor_module=self.actor_module_fsdp, - actor_optimizer=self.actor_optimizer) + self.config.actor.use_fused_kernels = use_fused_kernels + self.actor = DataParallelPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) if self._is_rollout: self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get('trust_remote_code', False)) + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: - self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path, - fsdp_config=self.config.ref.fsdp_config, - optim_config=None, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - trust_remote_code=self.config.model.get( - 'trust_remote_code', False), - use_liger=self.config.model.get('use_liger', False), - role='ref')[0] + self.ref_module_fsdp = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=self.config.ref.fsdp_config, + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents) - + checkpoint_config=self.config.actor.checkpoint, + ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) @@ -144,27 +166,28 @@ def init_model(self): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents) - + checkpoint_config=self.config.actor.checkpoint, + ) + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): assert self._is_ref # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info['micro_batch_size'] = micro_batch_size - data.meta_info['temperature'] = self.config.rollout.temperature - data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output = self.ref_policy.compute_log_prob(data=data) - output = DataProto.from_dict(tensors={'ref_log_prob': output}) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -180,21 +203,22 @@ def compute_log_prob(self, data: DataProto): load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info['temperature'] = self.config.rollout.temperature + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) output = self.actor.compute_log_prob(data=data) - output = DataProto.from_dict(tensors={'old_log_probs': output}, - meta_info={'temperature': self.config.rollout.temperature}) + output = DataProto.from_dict( + tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature} + ) output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') + output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -204,7 +228,7 @@ def compute_log_prob(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage('After compute_log_prob', logger=logger) + log_gpu_memory_usage("After compute_log_prob", logger=logger) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -215,48 +239,53 @@ def update_actor_dpo(self, data: DataProto): on pre-calculated log probabilities. """ # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) - assert self._is_actor # Make sure this worker has the actor role + assert self._is_actor # Make sure this worker has the actor role if self.actor is None: - raise RuntimeError("Actor instance (self.actor) not initialized in worker.") + raise RuntimeError("Actor instance (self.actor) not initialized in worker.") # --- FSDP State Management --- if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) - log_gpu_memory_usage('Before update policy (DPO via PPO path)', logger=logger) + log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) # --- Ulysses Sharding (if used) --- with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) # --- Call the core update method (now containing DPO logic) --- - with Timer(name='update_policy_dpo_via_ppo', logger=None) as timer: # Use a distinct timer name + with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name # Calls the modified update_policy method - metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION + metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION delta_time = timer.last # --- Add Performance Metrics --- # MFU calculation might be less accurate/meaningful here for DPO - metrics['perf/approx_tokens_processed'] = torch.sum(data.batch.get('attention_mask', torch.tensor(0))).item() # Approx tokens - metrics['perf/max_memory_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3) - metrics['perf/max_memory_reserved_gb'] = torch.cuda.max_memory_reserved() / (1024**3) - metrics['perf/cpu_memory_used_gb'] = psutil.virtual_memory().used / (1024**3) + metrics["perf/approx_tokens_processed"] = torch.sum( + data.batch.get("attention_mask", torch.tensor(0)) + ).item() # Approx tokens + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size # --- LR Scheduler Step --- - self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics['actor/lr'] = lr + metrics["actor/lr"] = lr + self.actor_lr_scheduler.step() - log_gpu_memory_usage('After update policy (DPO via PPO path)', logger=logger) + log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) # --- Prepare Output --- - output = DataProto(meta_info={'metrics': metrics}) + output = DataProto(meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to('cpu') + output = output.to("cpu") # --- FSDP State Management (Offload) --- if self._is_offload_param: @@ -265,7 +294,6 @@ def update_actor_dpo(self, data: DataProto): offload_fsdp_optimizer(optimizer=self.actor_optimizer) return output - # TODO(sgm): we may need to extract it to dp_reward_model.py @@ -277,8 +305,9 @@ class RewardModelWorker(Worker): def __init__(self, config): super().__init__() import torch.distributed + if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend=get_nccl_backend()) self.config = config # build device mesh for Ulysses Sequence Parallel @@ -289,16 +318,16 @@ def __init__(self, config): self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', - mesh_shape=(dp, self.ulysses_sequence_parallel_size), - mesh_dim_names=['dp', 'sp']) + self.ulysses_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self.use_remove_padding = self.config.model.get('use_remove_padding', False) + self.use_remove_padding = self.config.model.get("use_remove_padding", False) # normalize config if self.config.micro_batch_size is not None: @@ -319,29 +348,34 @@ def _build_model(self, config): else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, - trust_remote_code=config.model.get('trust_remote_code', False)) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) - trust_remote_code = config.model.get('trust_remote_code', False) + trust_remote_code = config.model.get("trust_remote_code", False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) model_config.num_labels = 1 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, - mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") model_config.classifier_dropout = 0.0 - reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) - - if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1: + reward_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) reward_module.to(torch.bfloat16) @@ -356,69 +390,69 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), forward_prefetch=False, - device_mesh=self.device_mesh) + device_mesh=self.device_mesh, + ) return reward_module @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get('external_lib', None)) + import_external_libs(self.config.model.get("external_lib", None)) self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input - from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] + with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False) # prevent model thinks we are generating + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad(reward_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) + reward_rmpad = gather_outputs_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) else: - output = self.reward_module(input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False) + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -430,9 +464,9 @@ def _forward_micro_batch(self, micro_batch): def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): batch_size = data.batch.batch_size[0] # expand as token_level_reward - attention_mask = data.batch['attention_mask'] - position_ids = data.batch['position_ids'] - response_length = data.batch['responses'].shape[-1] + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores @@ -443,7 +477,7 @@ def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): return token_level_scores def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch['attention_mask'].shape[-1] + src_max_length = data.batch["attention_mask"].shape[-1] src_tokenizer = self.input_tokenizer target_tokenizer = self.tokenizer @@ -453,44 +487,45 @@ def _switch_chat_template(self, data: DataProto): for i in range(data.batch.batch_size[0]): # extract raw prompt - if isinstance(data.non_tensor_batch['raw_prompt'][i], list): - chat: list = data.non_tensor_batch['raw_prompt'][i] + if isinstance(data.non_tensor_batch["raw_prompt"][i], list): + chat: list = data.non_tensor_batch["raw_prompt"][i] else: - chat: list = data.non_tensor_batch['raw_prompt'][i].tolist() + chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() # extract response - response_ids = data.batch['responses'][i] + response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] - valid_response_length = data.batch['attention_mask'][i][-response_length:].sum() + valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos - response = response.replace(src_tokenizer.eos_token, '') + response = response.replace(src_tokenizer.eos_token, "") - chat.append({'role': 'assistant', 'content': response}) + chat.append({"role": "assistant", "content": response}) - prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, - add_generation_prompt=False, - tokenize=False) + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) if self.rank == 0 and i == 0: # for debugging purpose - print(f'Switch template. chat: {prompt_with_chat_template}') + print(f"Switch template. chat: {prompt_with_chat_template}") # the maximum length is actually determined by the reward model itself - max_length = self.config.get('max_length', src_max_length) + max_length = self.config.get("max_length", src_max_length) if max_length is None: max_length = src_max_length - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors='pt', add_special_tokens=False) + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) input_ids, attention_mask = verl_F.postprocess_data( - input_ids=model_inputs['input_ids'], - attention_mask=model_inputs['attention_mask'], + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], max_length=max_length, pad_token_id=target_tokenizer.pad_token_id, left_pad=False, # right padding - truncation=self.config.get('truncation', 'right')) # truncate from the right + truncation=self.config.get("truncation", "right"), + ) # truncate from the right rm_input_ids.append(input_ids) rm_attention_mask.append(attention_mask) @@ -500,7 +535,7 @@ def _switch_chat_template(self, data: DataProto): rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids} + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} return DataProto.from_dict(rm_inputs) @@ -509,23 +544,24 @@ def compute_rm_score(self, data: DataProto): import itertools from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: - rm_input_ids = data.batch['input_ids'] - rm_attention_mask = data.batch['attention_mask'] - rm_position_ids = data.batch['position_ids'] + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] rm_inputs = { - 'input_ids': rm_input_ids, - 'attention_mask': rm_attention_mask, - 'position_ids': rm_position_ids + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, } rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) + rm_data.batch = rm_data.batch.to(get_device_id()) # perform forward computation with self.ulysses_sharding_manager: @@ -552,12 +588,12 @@ def compute_rm_score(self, data: DataProto): token_level_scores = self._expand_to_token_level(data, scores) # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}) + output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) output = self.ulysses_sharding_manager.postprocess_data(data=output) # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module self.reward_module._handle.reshard(True) - output = output.to('cpu') + output = output.to("cpu") return output diff --git a/recipe/spin/main_spin.py b/recipe/spin/main_spin.py index 679b78866..9a879ee77 100644 --- a/recipe/spin/main_spin.py +++ b/recipe/spin/main_spin.py @@ -33,7 +33,11 @@ def run_ppo(config) -> None: os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}}) + ray.init( + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + } + ) runner = TaskRunner.remote() ray.get(runner.run.remote(config)) @@ -63,8 +67,8 @@ def run(self, config): processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} # from recipe.spin.fsdp_workers import ActorRolloutRefWorker from recipe.spin.fsdp_workers import SPINRolloutRefWorker from verl.single_controller.ray import RayWorkerGroup @@ -98,7 +102,7 @@ def run(self, config): } if config.reward_model.enable: - if config.reward_model.strategy == "fsdp": + if config.reward_model.strategy in {"fsdp", "fsdp2"}: from recipe.spin.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker @@ -113,35 +117,40 @@ def run(self, config): role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "batch": - from verl.workers.reward_manager import BatchRewardManager + from verl.workers.reward_manager import get_reward_manager_cls - reward_manager_cls = BatchRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager - - reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError + # Note(haibin.lin): please make sure custom reward managers are imported and + # registered via `verl.workers.reward_manager.register` + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) compute_score = get_custom_reward_fn(config) reward_kwargs = dict(config.reward_model.get("reward_kwargs", {})) - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key, **reward_kwargs) + reward_fn = reward_manager_cls( + tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) + val_reward_fn = reward_manager_cls( + tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key + ) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - trainer = RaySPINTrainer(config=config, tokenizer=tokenizer, processor=processor, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn) + trainer = RaySPINTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + device_name=config.trainer.device, + ) trainer.init_workers() trainer.fit_dpo() diff --git a/recipe/spin/run_spin.sh b/recipe/spin/run_spin.sh index 41fcadf0e..798dedabe 100644 --- a/recipe/spin/run_spin.sh +++ b/recipe/spin/run_spin.sh @@ -18,9 +18,8 @@ CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ actor_rollout_ref.ref.log_prob_micro_batch_size=64 \ algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=True \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=4 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/recipe/spin/spin_trainer.py b/recipe/spin/spin_trainer.py index db1ea58d3..fa435dbdd 100644 --- a/recipe/spin/spin_trainer.py +++ b/recipe/spin/spin_trainer.py @@ -21,10 +21,11 @@ from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Dict, Optional, Type +from typing import Any, Optional import numpy as np import ray +import torch from codetiming import Timer from omegaconf import OmegaConf, open_dict from torch.utils.data import Dataset, Sampler @@ -37,38 +38,32 @@ from verl.single_controller.base import Worker from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.metric_utils import compute_throughout_metrics, compute_timing_metrics, process_validation_metrics, reduce_metrics +from verl.trainer.ppo.metric_utils import ( + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import Role from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger -WorkerType = Type[Worker] - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 +WorkerType = type[Worker] class AdvantageEstimator(str, Enum): """ Using an enumeration class to avoid spelling errors in adv_estimator """ - GAE = 'gae' - GRPO = 'grpo' - REINFORCE_PLUS_PLUS = 'reinforce_plus_plus' - REINFORCE_PLUS_PLUS_BASELINE = 'reinforce_plus_plus_baseline' - REMAX = 'remax' - RLOO = 'rloo' + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" @dataclass @@ -77,6 +72,7 @@ class ResourcePoolManager: Define a resource pool specification. Resource pool will be initialized first. Mapping """ + resource_pool_spec: dict[str, list[int]] mapping: dict[Role, str] resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) @@ -85,11 +81,11 @@ def create_resource_pool(self): for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, - use_gpu=True, - max_colocate_count=1, - name_prefix=resource_pool_name) + # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different + # WorkerGroup for different models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) self.resource_pool_dict[resource_pool_name] = resource_pool self._check_resource_available() @@ -105,15 +101,17 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get('GPU', 0) for node, node_info in node_available_resources.items()} + node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) if total_available_gpus < total_required_gpus: raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}") + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) # check each resource pool can be satisfied, O(#resource_pools * #nodes) for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): @@ -126,65 +124,59 @@ def _check_resource_available(self): break if num_nodes > 0: raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster" + f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this " + f"ray cluster" ) -from typing import Any - -import torch - -from verl.utils.torch_functional import masked_mean - - -def _compute_response_info(batch: DataProto) -> Dict[str, Any]: +def _compute_response_info(batch: DataProto) -> dict[str, Any]: """Placeholder: Computes prompt and response lengths.""" try: # Assuming 'prompts' and 'responses' keys exist after generation/union - prompt_len = batch.batch['prompts'].shape[1] - resp_len = batch.batch['responses'].shape[1] + prompt_len = batch.batch["prompts"].shape[1] + resp_len = batch.batch["responses"].shape[1] # This is simplified - real implementation might use attention masks # to get actual lengths per sample. batch_size = batch.batch.batch_size[0] - prompt_lengths_tensor = torch.full((batch_size,), prompt_len, - dtype=torch.float32, device=batch.batch.device) - response_lengths_tensor = torch.full((batch_size,), resp_len, - dtype=torch.float32, device=batch.batch.device) + prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device) + response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device) # Try getting actual lengths from attention mask if possible (more accurate) - if 'response_mask' in batch.batch: - response_lengths_tensor = batch.batch['response_mask'].sum(dim=1).float() - if 'attention_mask' in batch.batch and 'response_mask' in batch.batch: - full_mask = batch.batch['attention_mask'] - resp_mask = batch.batch['response_mask'] - # Infer prompt mask length based on where response mask starts or total length - # This logic depends heavily on how your masks are constructed. - # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor - # Fallback to using prompt shape if mask logic is complex: - prompt_lengths_tensor = torch.tensor([batch.batch['prompts'].shape[1]] * batch_size, - dtype=torch.float32, device=batch.batch.device) - + if "response_mask" in batch.batch: + response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float() + # if "attention_mask" in batch.batch and "response_mask" in batch.batch: + # full_mask = batch.batch["attention_mask"] + # resp_mask = batch.batch["response_mask"] + # Infer prompt mask length based on where response mask starts or total length + # This logic depends heavily on how your masks are constructed. + # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor + # Fallback to using prompt shape if mask logic is complex: + prompt_lengths_tensor = torch.tensor( + [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device + ) return { - 'prompt_length': prompt_lengths_tensor, - 'response_length': response_lengths_tensor, - 'max_response_length': resp_len, - 'max_prompt_length': prompt_len # Or from config if fixed padding + "prompt_length": prompt_lengths_tensor, + "response_length": response_lengths_tensor, + "max_response_length": resp_len, + "max_prompt_length": prompt_len, # Or from config if fixed padding } except KeyError as e: - print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.") - # Return default/dummy values if keys are missing - b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1 - max_resp = batch.batch.get('responses').shape[1] if batch.batch.get('responses') is not None else 0 - max_prompt = batch.batch.get('prompts').shape[1] if batch.batch.get('prompts') is not None else 0 - return { - 'prompt_length': torch.zeros(b_size), 'response_length': torch.zeros(b_size), - 'max_response_length': max_resp, 'max_prompt_length': max_prompt - } + print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.") + # Return default/dummy values if keys are missing + b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1 + max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0 + max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0 + return { + "prompt_length": torch.zeros(b_size), + "response_length": torch.zeros(b_size), + "max_response_length": max_resp, + "max_prompt_length": max_prompt, + } # --- Modified Metric Function --- -def compute_dpo_data_metrics(batch: DataProto) -> Dict[str, Any]: +def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: """ Computes and returns metrics relevant for the DPO-like process. Assumes 'batch' contains results after generation and preference marking, @@ -195,62 +187,73 @@ def compute_dpo_data_metrics(batch: DataProto) -> Dict[str, Any]: metrics = {} try: # --- Scores and Rewards (from reward_fn) --- - if 'token_level_scores' in batch.batch and batch.batch['token_level_scores'] is not None: - sequence_score = batch.batch['token_level_scores'].sum(-1) - metrics.update({ - 'reward/score/mean': torch.mean(sequence_score).item(), - 'reward/score/max': torch.max(sequence_score).item(), - 'reward/score/min': torch.min(sequence_score).item(), - }) - else: print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.") - - if 'token_level_rewards' in batch.batch and batch.batch['token_level_rewards'] is not None: - sequence_reward = batch.batch['token_level_rewards'].sum(-1) - metrics.update({ - 'reward/rewards/mean': torch.mean(sequence_reward).item(), - 'reward/rewards/max': torch.max(sequence_reward).item(), - 'reward/rewards/min': torch.min(sequence_reward).item(), - }) - else: print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.") + if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None: + sequence_score = batch.batch["token_level_scores"].sum(-1) + metrics.update( + { + "reward/score/mean": torch.mean(sequence_score).item(), + "reward/score/max": torch.max(sequence_score).item(), + "reward/score/min": torch.min(sequence_score).item(), + } + ) + else: + print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.") + + if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None: + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + metrics.update( + { + "reward/rewards/mean": torch.mean(sequence_reward).item(), + "reward/rewards/max": torch.max(sequence_reward).item(), + "reward/rewards/min": torch.min(sequence_reward).item(), + } + ) + else: + print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.") # --- DPO Specific Metrics (if stored previously) --- - if 'dpo_logits' in batch.batch and batch.batch['dpo_logits'] is not None: - metrics['actor/dpo_logits'] = batch.batch['dpo_logits'].mean().item() - else: print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.") + if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None: + metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.") - if 'chosen_logps' in batch.batch and batch.batch['chosen_logps'] is not None: - metrics['actor/chosen_logps'] = batch.batch['chosen_logps'].mean().item() - else: print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.") + if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None: + metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.") - if 'rejected_logps' in batch.batch and batch.batch['rejected_logps'] is not None: - metrics['actor/rejected_logps'] = batch.batch['rejected_logps'].mean().item() - else: print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.") + if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None: + metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.") # Add metrics based on the 'preferences' mask if available - if 'preferences' in batch.batch and batch.batch['preferences'] is not None: - prefs_mask = batch.batch['preferences'] # Shape [batch_size * n] - # Calculate accuracy based on RM scores (assuming higher score -> True in mask) - # Requires chosen/rejected scores to be available or recalculated - # This is complex here, better calculated in the main loop or update function + # if "preferences" in batch.batch and batch.batch["preferences"] is not None: + # prefs_mask = batch.batch["preferences"] # Shape [batch_size * n] + # Calculate accuracy based on RM scores (assuming higher score -> True in mask) + # Requires chosen/rejected scores to be available or recalculated + # This is complex here, better calculated in the main loop or update function # --- Length Metrics --- response_info = _compute_response_info(batch) - prompt_length = response_info['prompt_length'] - response_length = response_info['response_length'] - max_response_length = response_info['max_response_length'] - max_prompt_length = response_info['max_prompt_length'] # Use calculated or from config - - metrics.update({ - 'response_length/mean': torch.mean(response_length).item(), - 'response_length/max': torch.max(response_length).item(), - 'response_length/min': torch.min(response_length).item(), - 'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).item(), - 'prompt_length/mean': torch.mean(prompt_length).item(), - 'prompt_length/max': torch.max(prompt_length).item(), - 'prompt_length/min': torch.min(prompt_length).item(), - # Prompt clip ratio might need adjustment based on how max_prompt_length is defined - 'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(), - }) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + max_response_length = response_info["max_response_length"] + max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config + + metrics.update( + { + "response_length/mean": torch.mean(response_length).item(), + "response_length/max": torch.max(response_length).item(), + "response_length/min": torch.min(response_length).item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(), + "prompt_length/mean": torch.mean(prompt_length).item(), + "prompt_length/max": torch.max(prompt_length).item(), + "prompt_length/min": torch.min(prompt_length).item(), + # Prompt clip ratio might need adjustment based on how max_prompt_length is defined + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(), + } + ) except KeyError as e: print(f"ERROR in compute_dpo_data_metrics: Missing key {e}") @@ -262,18 +265,19 @@ def compute_dpo_data_metrics(batch: DataProto) -> Dict[str, Any]: return metrics -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'): - responses = data.batch['responses'] +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + responses = data.batch["responses"] response_length = responses.size(1) - token_level_scores = data.batch['token_level_scores'] + token_level_scores = data.batch["token_level_scores"] batch_size = data.batch.batch_size[0] - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] response_mask = attention_mask[:, -response_length:] # compute kl between ref_policy and current policy # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'], - kl_penalty=kl_penalty) # (batch_size, response_length) + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) kld = kld * response_mask beta = kl_ctrl.value @@ -284,17 +288,17 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch['token_level_rewards'] = token_level_rewards + data.batch["token_level_rewards"] = token_level_rewards - metrics = {'actor/reward_kl_penalty': current_kl, 'actor/reward_kl_penalty_coeff': beta} + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} return data, metrics def compute_response_mask(data: DataProto): - responses = data.batch['responses'] + responses = data.batch["responses"] response_length = responses.size(1) - attention_mask = data.batch['attention_mask'] + attention_mask = data.batch["attention_mask"] return attention_mask[:, -response_length:] @@ -307,8 +311,8 @@ def compute_onlineDPO_pref(data: DataProto): # print(f" Input batch keys: {list(data.batch.keys())}") # Check inputs - rewards_tensor = data.batch.get('token_level_rewards') - mask_tensor = data.batch.get('response_mask') + rewards_tensor = data.batch.get("token_level_rewards") + mask_tensor = data.batch.get("response_mask") if rewards_tensor is None or mask_tensor is None: print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!") @@ -317,30 +321,27 @@ def compute_onlineDPO_pref(data: DataProto): return data try: - preferences = core_algos.compute_onlinedpo_pref( - token_level_rewards=rewards_tensor, - response_mask=mask_tensor - ) + preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor) # Store the result - data.batch['preferences'] = preferences + data.batch["preferences"] = preferences except AttributeError: - print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!") - # Assign dummy value or raise error - data.batch['preferences'] = None # Indicate failure + print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!") + # Assign dummy value or raise error + data.batch["preferences"] = None # Indicate failure except Exception as e_pref: - print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}") - import traceback - traceback.print_exc() - data.batch['preferences'] = None # Indicate failure + print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}") + import traceback + + traceback.print_exc() + data.batch["preferences"] = None # Indicate failure # print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----") return data - @contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): +def _timer(name: str, timing_raw: dict[str, float]): with Timer(name=name, logger=None) as timer: yield timing_raw[name] = timer.last @@ -353,22 +354,23 @@ class RaySPINTrainer: # TODO: support each role have individual ray_worker_group_cls, # i.e., support different backend of different role - def __init__(self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - collate_fn=None, - train_sampler: Optional[Sampler] = None, - ): - - # assert torch.cuda.is_available(), 'cuda must be available on driver' + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name="cuda", + ): + # assert get_torch_device().is_available(), 'cuda must be available on driver' self.tokenizer = tokenizer self.processor = processor @@ -377,10 +379,10 @@ def __init__(self, self.val_reward_fn = val_reward_fn self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, 'Currently, only support hybrid engine' + assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}' + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager @@ -389,21 +391,13 @@ def __init__(self, self.ray_worker_group_cls = ray_worker_group_cls self.validation_generations_logger = ValidationGenerationsLogger() self.async_rollout_mode = False + self.device_name = device_name # define in-reward KL control # kl loss control currently not suppoorted if config.algorithm.use_kl_in_reward: self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) - # if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - # self.use_critic = True - # elif self.config.algorithm.adv_estimator in [ - # AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX, - # AdvantageEstimator.RLOO, AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE - # ]: - # self.use_critic = False - # else: - # raise NotImplementedError self.use_critic = False self._validate_config() self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) @@ -415,8 +409,9 @@ def _validate_config(self): # 1. Check total batch size for data correctness real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, \ + assert real_train_batch_size % n_gpus == 0, ( f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + ) # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". @@ -435,40 +430,50 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if mbs is None and mbs_per_gpu is None: raise ValueError( - f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." + ) if mbs is not None and mbs_per_gpu is not None: raise ValueError( f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. " - f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." + f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported " + f"(the former is deprecated)." ) if not config.actor_rollout_ref.actor.use_dynamic_bsz: # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor") + check_mutually_exclusive( + config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor", + ) if self.use_reference_policy: # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref") + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout") + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) if self.use_critic and not config.critic.use_dynamic_bsz: # Check for critic micro-batch size conflicts - check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, - "critic") + check_mutually_exclusive( + config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + ) # Check for reward model micro-batch size conflicts if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, - "reward_model") + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) # Actor # check if train_batch_size is larger than ppo_mini_batch_size @@ -477,13 +482,19 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # ppo_micro_batch_size * sequence_parallel_size >= n_gpus if not config.actor_rollout_ref.actor.use_dynamic_bsz: assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert ( + config.actor_rollout_ref.actor.ppo_mini_batch_size + % config.actor_rollout_ref.actor.ppo_micro_batch_size + == 0 + ) assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus assert config.actor_rollout_ref.actor.loss_agg_mode in [ - "token-mean", "seq-mean-token-sum", "seq-mean-token-mean" + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: @@ -492,32 +503,38 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # critic if self.use_critic and not config.critic.use_dynamic_bsz: assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get('ulysses_sequence_parallel_size', 1) + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) if config.critic.ppo_micro_batch_size is not None: assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == 'fsdp': - if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \ - config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.actor_rollout_ref.model.use_remove_padding, \ + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + if ( + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + ): + assert config.actor_rollout_ref.model.use_remove_padding, ( "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) - if self.use_critic and config.critic.strategy == 'fsdp': - if config.critic.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.critic.model.use_remove_padding, \ + if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, ( "When using sequence parallelism for critic, you must enable `use_remove_padding`." + ) - if config.data.get('val_batch_size', None) is not None: + if config.data.get("val_batch_size", None) is not None: print( - "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." + "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines " + "as a whole batch, which will schedule the memory themselves." ) # check eval config if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, \ + assert config.actor_rollout_ref.rollout.temperature > 0, ( "validation gen temperature should be greater than 0 when enabling do_sample" + ) print("[validate_config] All configuration checks passed successfully!") @@ -529,9 +546,13 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler if train_dataset is None: - train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor) + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) if val_dataset is None: - val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor) + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) self.train_dataset, self.val_dataset = train_dataset, val_dataset if train_sampler is None: @@ -566,7 +587,10 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" - print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + print( + f"Size of train dataloader: {len(self.train_dataloader)}, " + f"Size of val dataloader: {len(self.val_dataloader)}" + ) total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs @@ -597,7 +621,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): import numpy as np # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) + samples = list(zip(inputs, outputs, scores, strict=True)) samples.sort(key=lambda x: x[0]) # Sort by input text # Use fixed random seed for deterministic shuffling @@ -623,7 +647,9 @@ def _validate(self): test_batch = DataProto.from_single_dict(test_data) # repeat test batch - test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True) + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) # we only do validation on rule-based rm if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": @@ -662,9 +688,7 @@ def _validate(self): if not self.async_rollout_mode: test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) else: - self.async_rollout_manager.wake_up() test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) - self.async_rollout_manager.sleep() # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) @@ -707,18 +731,24 @@ def _validate(self): assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" data_sources = np.concatenate(data_source_lst, axis=0) - print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print - print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print + print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print + print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) - print(f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}") # Added Print + print( + f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}" + ) # Added Print metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): core_var = "acc" if "acc" in var2metric2val else "reward" for var_name, metric2val in var2metric2val.items(): n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) for metric_name, metric_val in metric2val.items(): - if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): metric_sec = "val-core" else: metric_sec = "val-aux" @@ -726,7 +756,7 @@ def _validate(self): metric_dict[pfx] = metric_val return metric_dict - + def init_workers(self): """Init resource pool and worker group""" self.resource_pool_manager.create_resource_pool() @@ -736,10 +766,12 @@ def init_workers(self): # create actor and rollout if self.hybrid_engine: resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls else: raise NotImplementedError @@ -747,26 +779,28 @@ def init_workers(self): if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) - self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls # create reference policy if needed if self.use_reference_policy: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref" + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls # create a reward model if reward_fn is None if self.use_rm: # we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # NOTE: if you want to use a different resource pool for each role, which can support different + # parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to + # different worker groups. # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. all_wg = {} self.wg_dicts = [] @@ -776,83 +810,95 @@ def init_workers(self): for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - **wg_kwargs) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 self.wg_dicts.append(wg_dict) if self.use_critic: - self.critic_wg = all_wg['critic'] + self.critic_wg = all_wg["critic"] self.critic_wg.init_model() if self.use_reference_policy: - self.ref_policy_wg = all_wg['ref'] + self.ref_policy_wg = all_wg["ref"] self.ref_policy_wg.init_model() if self.use_rm: - self.rm_wg = all_wg['rm'] + self.rm_wg = all_wg["rm"] self.rm_wg.init_model() # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg['actor_rollout'] + self.actor_rollout_wg = all_wg["actor_rollout"] self.actor_rollout_wg.init_model() def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, - f'global_step_{self.global_steps}') + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) - print(f'local_global_step_folder: {local_global_step_folder}') - actor_local_path = os.path.join(local_global_step_folder, 'actor') + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor') + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) - remove_previous_ckpt_in_save = self.config.trainer.get('remove_previous_ckpt_in_save', False) + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) if remove_previous_ckpt_in_save: print( - 'Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead' + "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and " + "max_critic_ckpt_to_keep=1 instead" ) - max_actor_ckpt_to_keep = self.config.trainer.get('max_actor_ckpt_to_keep', - None) if not remove_previous_ckpt_in_save else 1 - max_critic_ckpt_to_keep = self.config.trainer.get('max_critic_ckpt_to_keep', - None) if not remove_previous_ckpt_in_save else 1 + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) - self.actor_rollout_wg.save_checkpoint(actor_local_path, - actor_remote_path, - self.global_steps, - max_ckpt_to_keep=max_actor_ckpt_to_keep) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, 'critic') - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic') - self.critic_wg.save_checkpoint(critic_local_path, - critic_remote_path, - self.global_steps, - max_ckpt_to_keep=max_critic_ckpt_to_keep) + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") dataloader_state_dict = self.train_dataloader.state_dict() torch.save(dataloader_state_dict, dataloader_local_path) # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, - 'latest_checkpointed_iteration.txt') - with open(local_latest_checkpointed_iteration, 'w') as f: + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.global_steps)) def _load_checkpoint(self): - if self.config.trainer.resume_mode == 'disable': + if self.config.trainer.resume_mode == "disable": return 0 # load from hdfs if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError('load from hdfs is not implemented yet') + raise NotImplementedError("load from hdfs is not implemented yet") else: checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path if not os.path.isabs(checkpoint_folder): @@ -861,63 +907,66 @@ def _load_checkpoint(self): global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder - if self.config.trainer.resume_mode == 'auto': + if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print('Training from scratch') + print("Training from scratch") return 0 else: if self.config.trainer.resume_mode == "resume_path": assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() global_step_folder = os.path.join(working_dir, global_step_folder) - print(f'Load from checkpoint folder: {global_step_folder}') + print(f"Load from checkpoint folder: {global_step_folder}") # set global step - self.global_steps = int(global_step_folder.split('global_step_')[-1]) + self.global_steps = int(global_step_folder.split("global_step_")[-1]) - print(f'Setting global step to {self.global_steps}') - print(f'Resuming from {global_step_folder}') + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") - actor_path = os.path.join(global_step_folder, 'actor') - critic_path = os.path.join(global_step_folder, 'critic') + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") # load actor - self.actor_rollout_wg.load_checkpoint(actor_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load critic if self.use_critic: - self.critic_wg.load_checkpoint(critic_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load dataloader, # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, 'data.pt') + dataloader_local_path = os.path.join(global_step_folder, "data.pt") if os.path.exists(dataloader_local_path): dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) self.train_dataloader.load_state_dict(dataloader_state_dict) else: print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'): + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch['attention_mask'] + attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, - k_partitions=world_size, - equal_size=True) + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) # reorder based on index. The data will be automatically equally partitioned by dispatch function global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, - partitions=global_partition_lst, - prefix=logging_prefix) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) metrics.update(global_balance_stats) - - def fit_dpo(self): # Renamed for clarity as standard PPO loop + def fit_dpo(self): # Renamed for clarity as standard PPO loop """ The training loop of Online DPO using a periodically updated reference model. The driver process calls worker groups for computation. @@ -932,10 +981,12 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # Initialize logger logger = None try: - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False)) + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False), + ) except Exception as e: print(f"Warning: Failed to initialize logger: {e}") @@ -943,61 +994,84 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # Load checkpoint before doing anything loaded_step = self._load_checkpoint() self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1 - print(f"Starting Online DPO training from global step {self.global_steps}. Total steps: {self.total_training_steps}") + print( + f"Starting Online DPO training from global step {self.global_steps}. " + f"Total steps: {self.total_training_steps}" + ) print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}") # Check if reference policy is configured correctly for this mode if not self.use_reference_policy: - print("WARNING: 'use_reference_policy' is False. Periodic reference model update requires a reference policy worker. DPO updates might fail or use incorrect logic.") - # Consider raising an error if strict adherence is required: - # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True and a configured reference worker.") - + print( + "WARNING: 'use_reference_policy' is False. Periodic reference model update requires a " + "reference policy worker. DPO updates might fail or use incorrect logic." + ) + # Consider raising an error if strict adherence is required: + # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True " + # "and a configured reference worker.") # Perform validation before training - if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): print("Running validation before Online DPO training...") val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') - if logger and val_metrics: logger.log(data=val_metrics, step=max(0, self.global_steps - 1)) - if self.config.trainer.get('val_only', False): + pprint(f"Initial validation metrics: {val_metrics}") + if logger and val_metrics: + logger.log(data=val_metrics, step=max(0, self.global_steps - 1)) + if self.config.trainer.get("val_only", False): print("Validation only mode enabled. Exiting training.") - if logger and hasattr(logger, 'finish'): logger.finish() + if logger and hasattr(logger, "finish"): + logger.finish() return # Add tqdm progress bar - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Online DPO Training Progress", position=0, leave=True) + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Online DPO Training Progress", + position=0, + leave=True, + ) last_val_metrics = None should_stop = False for epoch in range(self.config.trainer.total_epochs): - if should_stop: break + if should_stop: + break print(f"--- Starting Online DPO Epoch {epoch} ---") try: train_iterator = iter(self.train_dataloader) except TypeError: print("Warning: Dataloader is not iterable.") - train_iterator = self.train_dataloader # Fallback attempt + train_iterator = self.train_dataloader # Fallback attempt for batch_idx, batch_dict in enumerate(train_iterator): if self.global_steps > self.total_training_steps: - should_stop = True; break + should_stop = True + break metrics = {} timing_raw = {} step_timer = Timer(logger=None) - ref_log_prob_computed = False # Flag to track if ref log probs were computed + ref_log_prob_computed = False # Flag to track if ref log probs were computed - try: # Outer try-except for the whole step + try: # Outer try-except for the whole step step_timer.start() - with _timer('step', timing_raw): + with _timer("step", timing_raw): batch: DataProto = DataProto.from_single_dict(batch_dict) current_batch_size = batch.batch.batch_size[0] - print(f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: {current_batch_size}") + print( + f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: " + f"{current_batch_size}" + ) # --- Reference Model Update --- - ref_update_freq = self.config.trainer.get('ref_update_freq', -1) - if self.use_reference_policy and ref_update_freq > 0 and self.global_steps % ref_update_freq == 0: + ref_update_freq = self.config.trainer.get("ref_update_freq", -1) + if ( + self.use_reference_policy + and ref_update_freq > 0 + and self.global_steps % ref_update_freq == 0 + ): print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...") try: # --- This requires careful implementation with FSDP --- @@ -1005,12 +1079,12 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # This needs to be done collectively across actor worker ranks. # The checkpoint_manager might be adaptable, or use FSDP APIs directly. # Example placeholder using a conceptual save/load mechanism: - actor_state_path = "/tmp/actor_state_mid" # Temporary path - self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic + actor_state_path = "/tmp/actor_state_mid" # Temporary path + self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic # 2. Load the state dict onto the reference model worker group # This also needs collective loading on the ref worker ranks. - self.ref_policy_wg.load_checkpoint(actor_state_path,None, True) # Adapt load logic + self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic print(f"[Step {self.global_steps}] Reference Model Weights Updated.") # Optionally remove the temporary state file @@ -1021,180 +1095,225 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop traceback.print_exc() # Pop keys for generation - pop_batch_keys=['input_ids', 'attention_mask'] - if 'position_ids' in batch.batch: pop_batch_keys.append('position_ids') - pop_non_tensor_keys = ['raw_prompt_ids'] if 'raw_prompt_ids' in batch.non_tensor_batch else [] - if 'multi_modal_inputs' in batch.non_tensor_batch.keys(): - pop_non_tensor_keys.extend(['multi_modal_data', 'multi_modal_inputs']) + pop_batch_keys = ["input_ids", "attention_mask"] + if "position_ids" in batch.batch: + pop_batch_keys.append("position_ids") + pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else [] + if "multi_modal_inputs" in batch.non_tensor_batch.keys(): + pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"]) original_non_tensor_data = batch.non_tensor_batch gen_batch = batch.pop( batch_keys=pop_batch_keys, non_tensor_batch_keys=pop_non_tensor_keys, ) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) # (Add Debug prints for gen_batch if needed) # Generate sequences (chosen/rejected pairs) - with _timer('gen', timing_raw): + with _timer("gen", timing_raw): try: gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) # (Add Debug prints for gen_batch_output if needed) except Exception as gen_e: print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!") - print(gen_e); traceback.print_exc() + print(gen_e) + traceback.print_exc() print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - step_timer.stop(); continue + step_timer.stop() + continue # Combine original prompts with generated sequences - batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data - batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object) + batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object + ) batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) # (Add Debug prints after union if needed) # Compute response mask (needed for ref logprob calc and DPO prep) - batch.batch['response_mask'] = compute_response_mask(batch) + batch.batch["response_mask"] = compute_response_mask(batch) if self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) - batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef fallback) --- + # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef + # fallback) --- # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed # unless used for other metrics or a fallback. Keep it for now. - with _timer('policy_log_prob', timing_raw): - policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs' - # (Debug prints for old_log_probs) + with _timer("policy_log_prob", timing_raw): + policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs' + # (Debug prints for old_log_probs) # --- Compute Log Probs using the EXTERNAL Reference Model --- if self.use_reference_policy: - with _timer('ref_log_prob_dpo', timing_raw): + with _timer("ref_log_prob_dpo", timing_raw): # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----") try: # 'batch' contains interleaved chosen/rejected sequences - ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(batch) # Returns DataProto with 'ref_log_prob' - batch = batch.union(ref_log_prob_output) # Adds 'ref_log_prob' key [batch_size * n, seq_len] - ref_log_prob_computed = True # Mark success - # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: {batch.batch['ref_log_prob'].shape} ----") + ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob( + batch + ) # Returns DataProto with 'ref_log_prob' + batch = batch.union( + ref_log_prob_output + ) # Adds 'ref_log_prob' key [batch_size * n, seq_len] + ref_log_prob_computed = True # Mark success + # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: " + # f"{batch.batch['ref_log_prob'].shape} ----") except Exception as ref_e: - print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}") - traceback.print_exc() - batch.batch['ref_log_prob'] = None # Mark as failed - ref_log_prob_computed = False + print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}") + traceback.print_exc() + batch.batch["ref_log_prob"] = None # Mark as failed + ref_log_prob_computed = False else: - print("Warning: Skipping external reference log prob calculation as use_reference_policy is False.") + print( + "Warning: Skipping external reference log prob calculation as use_reference_policy " + "is False." + ) # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor - # --- Compute Rewards/Scores (used to determine preference) --- - with _timer('reward_calc', timing_raw): - # (Reward calculation logic using RM or reward_fn as before) - # ... Ensure this calculates 'token_level_rewards' or similar ... + with _timer("reward_calc", timing_raw): + # (Reward calculation logic using RM or reward_fn as before) + # ... Ensure this calculates 'token_level_rewards' or similar ... if self.use_rm: reward_tensor_rm = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor_rm) # Adds 'rm_scores' + batch = batch.union(reward_tensor_rm) # Adds 'rm_scores' reward_extra_infos_dict = {} try: if self.reward_fn is None: - # print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! Using dummy rewards. ----") - # Use rm_scores if available, otherwise zeros - reward_tensor = batch.batch.get('rm_scores', torch.zeros_like(batch.batch['response_mask'], dtype=torch.float32)) + # print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! " + # f"Using dummy rewards. ----") + # Use rm_scores if available, otherwise zeros + reward_tensor = batch.batch.get( + "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) + ) else: - reward_result = self.reward_fn(batch, return_dict=True) - reward_tensor = reward_result['reward_tensor'] # Final combined reward - reward_extra_infos_dict = reward_result.get('reward_extra_info', {}) + reward_result = self.reward_fn(batch, return_dict=True) + reward_tensor = reward_result["reward_tensor"] # Final combined reward + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) except Exception: - # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. Using dummy rewards. ----') + # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. ' + # f'Using dummy rewards. ----') traceback.print_exc() - reward_tensor = torch.zeros_like(batch.batch['response_mask'], dtype=torch.float32) + reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) reward_extra_infos_dict = {} # Use 'token_level_rewards' as the key for preference calculation - batch.batch['token_level_rewards'] = reward_tensor - if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) - + batch.batch["token_level_rewards"] = reward_tensor + if reward_extra_infos_dict: + batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) # --- Determine Preferences --- # Uses 'token_level_rewards' to determine chosen/rejected based on score - batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key + batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key # --- Prepare DPO Batch --- - dpo_update_batch_proto = None # Initialize - with _timer('prepare_dpo_batch', timing_raw): + dpo_update_batch_proto = None # Initialize + with _timer("prepare_dpo_batch", timing_raw): try: - if 'preferences' not in batch.batch or batch.batch['preferences'] is None: + if "preferences" not in batch.batch or batch.batch["preferences"] is None: raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.") # Check if reference log probs were computed successfully (if needed) if self.use_reference_policy and not ref_log_prob_computed: - raise ValueError("Reference log probs required but failed to compute.") + raise ValueError("Reference log probs required but failed to compute.") # Check required base keys - required_keys = ['input_ids', 'attention_mask', 'response_mask'] + required_keys = ["input_ids", "attention_mask", "response_mask"] for rk in required_keys: if rk not in batch.batch or batch.batch[rk] is None: raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.") - preferences_mask = batch.batch['preferences'] # Shape [batch_size * n] + preferences_mask = batch.batch["preferences"] # Shape [batch_size * n] not_preferences_mask = ~preferences_mask # Gather Chosen/Rejected Base Tensors - chosen_input_ids = batch.batch['input_ids'][preferences_mask] - chosen_attention_mask = batch.batch['attention_mask'][preferences_mask] - rejected_input_ids = batch.batch['input_ids'][not_preferences_mask] - rejected_attention_mask = batch.batch['attention_mask'][not_preferences_mask] - chosen_position_ids = batch.batch.get('position_ids')[preferences_mask] if 'position_ids' in batch.batch else None - rejected_position_ids = batch.batch.get('position_ids')[not_preferences_mask] if 'position_ids' in batch.batch else None + chosen_input_ids = batch.batch["input_ids"][preferences_mask] + chosen_attention_mask = batch.batch["attention_mask"][preferences_mask] + rejected_input_ids = batch.batch["input_ids"][not_preferences_mask] + rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask] + chosen_position_ids = ( + batch.batch.get("position_ids")[preferences_mask] + if "position_ids" in batch.batch + else None + ) + rejected_position_ids = ( + batch.batch.get("position_ids")[not_preferences_mask] + if "position_ids" in batch.batch + else None + ) # Create Labels print("WARNING: Creating DPO labels using configured max_prompt_length...") prompt_len = self.config.data.max_prompt_length - chosen_labels = chosen_input_ids.clone(); chosen_labels[:, :prompt_len] = -100 - rejected_labels = rejected_input_ids.clone(); rejected_labels[:, :prompt_len] = -100 + chosen_labels = chosen_input_ids.clone() + chosen_labels[:, :prompt_len] = -100 + rejected_labels = rejected_input_ids.clone() + rejected_labels[:, :prompt_len] = -100 # Calculate and Gather Reference Log Probs (Sequence Level) if self.use_reference_policy: - ref_log_prob_tensor = batch.batch['ref_log_prob'] # Token level [bsz * n, seq_len] - response_mask_full = batch.batch['response_mask'] # Response mask [bsz * n, seq_len] - ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(dim=-1) # Sequence level [bsz * n] + ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len] + response_mask_full = batch.batch[ + "response_mask" + ] # Response mask [bsz * n, seq_len] + ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum( + dim=-1 + ) # Sequence level [bsz * n] reference_chosen_logps = ref_sequence_logps[preferences_mask] reference_rejected_logps = ref_sequence_logps[not_preferences_mask] else: - # If not using external ref, DPO needs ActorAsRef logic in dp_actor - # We won't add the keys here, dp_actor will handle it (or fail if not modified) - print("Info: Not adding explicit reference logps to DPO batch (use_reference_policy=False).") - reference_chosen_logps = None # Explicitly None - reference_rejected_logps = None + # If not using external ref, DPO needs ActorAsRef logic in dp_actor + # We won't add the keys here, dp_actor will handle it (or fail if not modified) + print( + "Info: Not adding explicit reference logps to DPO batch " + "(use_reference_policy=False)." + ) + reference_chosen_logps = None # Explicitly None + reference_rejected_logps = None # Package Tensors dpo_tensors = { - 'chosen_input_ids': chosen_input_ids, - 'chosen_attention_mask': chosen_attention_mask, - 'chosen_labels': chosen_labels, - 'rejected_input_ids': rejected_input_ids, - 'rejected_attention_mask': rejected_attention_mask, - 'rejected_labels': rejected_labels, + "chosen_input_ids": chosen_input_ids, + "chosen_attention_mask": chosen_attention_mask, + "chosen_labels": chosen_labels, + "rejected_input_ids": rejected_input_ids, + "rejected_attention_mask": rejected_attention_mask, + "rejected_labels": rejected_labels, } # Conditionally add reference logps if computed if reference_chosen_logps is not None: - dpo_tensors['reference_chosen_logps'] = reference_chosen_logps + dpo_tensors["reference_chosen_logps"] = reference_chosen_logps if reference_rejected_logps is not None: - dpo_tensors['reference_rejected_logps'] = reference_rejected_logps + dpo_tensors["reference_rejected_logps"] = reference_rejected_logps # Add position ids if they exist - if chosen_position_ids is not None: dpo_tensors['chosen_position_ids'] = chosen_position_ids - if rejected_position_ids is not None: dpo_tensors['rejected_position_ids'] = rejected_position_ids + if chosen_position_ids is not None: + dpo_tensors["chosen_position_ids"] = chosen_position_ids + if rejected_position_ids is not None: + dpo_tensors["rejected_position_ids"] = rejected_position_ids # Prepare Meta Info dpo_meta = { - 'dpo_beta': OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1), - 'dpo_loss_type': OmegaConf.select(self.config.algorithm, "dpo_loss_type", default='sigmoid'), - 'dpo_label_smoothing': OmegaConf.select(self.config.algorithm, "dpo_label_smoothing", default=0.0), - 'use_reference_policy': self.use_reference_policy, - 'reference_free': not self.use_reference_policy, # False if using external ref - 'global_step': self.global_steps + "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1), + "dpo_loss_type": OmegaConf.select( + self.config.algorithm, "dpo_loss_type", default="sigmoid" + ), + "dpo_label_smoothing": OmegaConf.select( + self.config.algorithm, "dpo_label_smoothing", default=0.0 + ), + "use_reference_policy": self.use_reference_policy, + "reference_free": not self.use_reference_policy, # False if using external ref + "global_step": self.global_steps, } dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta) @@ -1205,70 +1324,86 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop except Exception as e_prep: print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}") traceback.print_exc() - dpo_update_batch_proto = None # Skip update on error - + dpo_update_batch_proto = None # Skip update on error # --- Actor Update Step --- actor_output = None if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto: - with _timer('update_actor', timing_raw): + with _timer("update_actor", timing_raw): # Pass the batch containing reference log probs (if computed) # The modified update_actor_dpo expects them if reference_free=False actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto) - if actor_output and 'metrics' in actor_output.meta_info: - metrics.update(reduce_metrics(actor_output.meta_info['metrics'])) + if actor_output and "metrics" in actor_output.meta_info: + metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) elif dpo_update_batch_proto is None: - print(f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error.") - + print( + f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error." + ) # --- Validation and Saving --- - test_freq = OmegaConf.select(self.config.trainer, "test_freq", default = -1) + test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1) is_last_step = self.global_steps >= self.total_training_steps - if self.val_reward_fn is not None and test_freq > 0 and (is_last_step or self.global_steps % test_freq == 0): + if ( + self.val_reward_fn is not None + and test_freq > 0 + and (is_last_step or self.global_steps % test_freq == 0) + ): print(f"\nRunning DPO validation at step {self.global_steps}...") val_timing_raw = {} - with _timer('testing', val_timing_raw): + with _timer("testing", val_timing_raw): val_metrics: dict = self._validate() - if is_last_step: last_val_metrics = val_metrics + if is_last_step: + last_val_metrics = val_metrics if val_metrics: - metrics['time/validation_run'] = val_timing_raw.get('testing', 0) + metrics["time/validation_run"] = val_timing_raw.get("testing", 0) metrics.update(val_metrics) - else: print("Validation skipped or returned no metrics.") + else: + print("Validation skipped or returned no metrics.") - save_freq = OmegaConf.select(self.config.trainer, "save_freq", default = -1) - if save_freq > 0 and ( is_last_step or self.global_steps % save_freq == 0): + save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) + if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0): print(f"\nSaving DPO checkpoint at step {self.global_steps}...") - with _timer('save_checkpoint', timing_raw): - self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere) - metrics['time/save_checkpoint'] = timing_raw.get('save_checkpoint', 0) + with _timer("save_checkpoint", timing_raw): + self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere) + metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0) # --- End main step timer context --- # --- Metrics calculation AFTER the 'step' timer block --- - metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics + metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) n_gpus = self.resource_pool_manager.get_n_gpus() - if 'step' in timing_raw: - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + if "step" in timing_raw: + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) else: - print(f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. Skipping throughput.") + print( + f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. " + f"Skipping throughput." + ) step_timer.stop() - metrics['time/step'] = step_timer.last + metrics["time/step"] = step_timer.last # Log metrics - log_freq = OmegaConf.select(self.config.trainer, "log_freq", default = 1) + log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1) if logger and self.global_steps % log_freq == 0: log_payload = metrics.copy() # Add learning rate to log payload - if actor_output and 'actor/lr' in metrics: log_payload['actor/lr'] = metrics['actor/lr'] + if actor_output and "actor/lr" in metrics: + log_payload["actor/lr"] = metrics["actor/lr"] print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}") - try: logger.log(data=log_payload, step=self.global_steps) - except Exception as e: print(f"Logging failed at step {self.global_steps}: {e}") + try: + logger.log(data=log_payload, step=self.global_steps) + except Exception as e: + print(f"Logging failed at step {self.global_steps}: {e}") # Update progress bar - postfix_metrics = {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in metrics.items() if isinstance(v, (int, float))} + postfix_metrics = { + k: f"{v:.3f}" if isinstance(v, float) else v + for k, v in metrics.items() + if isinstance(v, int | float) + } progress_bar.set_postfix(postfix_metrics) except Exception as step_e: @@ -1276,41 +1411,48 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop print(f"Caught Exception: {step_e}") traceback.print_exc() print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - step_timer.stop(); should_stop = True; break + step_timer.stop() + should_stop = True + break if is_last_step or should_stop: - print(f'Stopping DPO training at step {self.global_steps}.') + print(f"Stopping DPO training at step {self.global_steps}.") break self.global_steps += 1 progress_bar.update(1) # End of epoch handling - if hasattr(self.train_dataloader, 'reset'): - try: self.train_dataloader.reset() - except Exception as e: print(f"Warning: Failed to reset train dataloader state: {e}") - if should_stop: break + if hasattr(self.train_dataloader, "reset"): + try: + self.train_dataloader.reset() + except Exception as e: + print(f"Warning: Failed to reset train dataloader state: {e}") + if should_stop: + break # --- Final cleanup and logging --- progress_bar.close() final_step = max(0, self.global_steps - 1) print(f"Online DPO Training finished at step {final_step}.") # Save final checkpoint - save_freq = OmegaConf.select(self.config.trainer, "save_freq", default = -1) - if not self.config.trainer.get('val_only', False) and (save_freq <= 0 or final_step % save_freq != 0) : + save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) + if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0): print(f"Saving final DPO checkpoint at step {final_step}...") self._save_checkpoint() # Final validation run - if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get('val_only', False): - print("Running final validation...") - last_val_metrics = self._validate() - if last_val_metrics and logger: - last_val_metrics['final_validation'] = True - try: logger.log(data=last_val_metrics, step=final_step) - except Exception as e: print(f"[Final Val Metrics Log Error]: {e}") - - pprint(f'Final validation metrics: {last_val_metrics}') - if logger and hasattr(logger, 'finish'): logger.finish() + if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False): + print("Running final validation...") + last_val_metrics = self._validate() + if last_val_metrics and logger: + last_val_metrics["final_validation"] = True + try: + logger.log(data=last_val_metrics, step=final_step) + except Exception as e: + print(f"[Final Val Metrics Log Error]: {e}") + + pprint(f"Final validation metrics: {last_val_metrics}") + if logger and hasattr(logger, "finish"): + logger.finish() print("Online DPO Training Run Complete.") - diff --git a/recipe/sppo/dp_actor.py b/recipe/sppo/dp_actor.py index a58d1c31f..df14c0b4e 100644 --- a/recipe/sppo/dp_actor.py +++ b/recipe/sppo/dp_actor.py @@ -21,7 +21,8 @@ import verl.utils.torch_functional as verl_F from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, kl_penalty -from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import rearrange_micro_batches from verl.workers.actor.dp_actor import DataParallelPPOActor @@ -88,14 +89,18 @@ def update_policy(self, data: DataProto): # split batch into micro_batches mini_batch = data if has_multi_modal_inputs: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) # split batch into micro_batches micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) @@ -104,9 +109,9 @@ def update_policy(self, data: DataProto): for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload + data = data.to(get_device_id()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) attention_mask = data["attention_mask"] @@ -126,7 +131,9 @@ def update_policy(self, data: DataProto): calculate_entropy = False if entropy_coeff != 0: calculate_entropy = True - entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy) + entropy, log_prob = self._forward_micro_batch( + micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy + ) pg_loss, log_ratios, preference = compute_sppo_loss( old_log_prob=old_log_prob, @@ -148,8 +155,12 @@ def update_policy(self, data: DataProto): if self.config.use_kl_loss: ref_log_prob = data["ref_log_prob"] # compute kl loss - kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss( + loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode + ) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics["actor/kl_loss"] = kl_loss.detach().item() diff --git a/recipe/sppo/main_sppo.py b/recipe/sppo/main_sppo.py index 315aa6f0a..d99f4f2dc 100644 --- a/recipe/sppo/main_sppo.py +++ b/recipe/sppo/main_sppo.py @@ -23,7 +23,6 @@ import ray from verl.trainer.ppo.reward import load_reward_manager -from verl.utils.device import is_cuda_available from .sppo_ray_trainer import RaySPPOTrainer @@ -40,7 +39,9 @@ def run_ppo(config) -> None: if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}}, + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + }, num_cpus=config.ray_init.num_cpus, ) @@ -72,8 +73,8 @@ def run(self, config): processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup from .sppo_worker import SPPOActorRolloutRefWorker # , CriticWorker @@ -114,7 +115,7 @@ def run(self, config): # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == "fsdp": + if config.reward_model.strategy in {"fsdp", "fsdp2"}: from verl.workers.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker @@ -128,7 +129,9 @@ def run(self, config): role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) @@ -141,7 +144,7 @@ def run(self, config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, - device_name="cuda" if is_cuda_available else "npu", + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit() diff --git a/recipe/sppo/run_qwen2.5-7b_rm.sh b/recipe/sppo/run_qwen2.5-7b_rm.sh index 46c459b86..1a4c02686 100755 --- a/recipe/sppo/run_qwen2.5-7b_rm.sh +++ b/recipe/sppo/run_qwen2.5-7b_rm.sh @@ -43,7 +43,7 @@ python3 -m recipe.sppo.main_sppo \ actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='sppo-sglang' \ trainer.val_before_train=True \ trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py index 4a401308e..15e2f9c40 100644 --- a/recipe/sppo/sppo_ray_trainer.py +++ b/recipe/sppo/sppo_ray_trainer.py @@ -34,8 +34,17 @@ from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import reduce_metrics -from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, apply_kl_penalty, compute_response_mask +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + ResourcePoolManager, + Role, + WorkerType, + apply_kl_penalty, + compute_response_mask, +) from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.utils.profiler.performance import simple_timer from verl.utils.tracking import ValidationGenerationsLogger @@ -176,21 +185,22 @@ def fit(self): batch_keys=batch_keys_to_pop, non_tensor_batch_keys=non_tensor_batch_keys_to_pop, ) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) is_last_step = self.global_steps >= self.total_training_steps - with _timer("step", timing_raw): + with simple_timer("step", timing_raw): # generate a batch - with _timer("gen", timing_raw): + with simple_timer("gen", timing_raw): if not self.async_rollout_mode: gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: - self.async_rollout_manager.wake_up() gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - self.async_rollout_manager.sleep() + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): + with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) @@ -205,7 +215,9 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -222,7 +234,7 @@ def fit(self): # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - with _timer("reward", timing_raw): + with simple_timer("reward", timing_raw): # compute reward model score if self.use_rm: reward_tensor = self.rm_wg.compute_rm_score(batch) @@ -234,43 +246,44 @@ def fit(self): reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) # recompute old_log_probs - with _timer("old_log_prob", timing_raw): + with simple_timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob - with _timer("ref", timing_raw): + with simple_timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: - with _timer("values", timing_raw): + with simple_timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): + with simple_timer("adv", timing_raw): # we combine with rule-based rm reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] @@ -281,7 +294,7 @@ def fit(self): # update critic if self.use_critic: - with _timer("update_critic", timing_raw): + with simple_timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) @@ -289,7 +302,7 @@ def fit(self): # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer("update_actor", timing_raw): + with simple_timer("update_actor", timing_raw): batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) @@ -298,7 +311,7 @@ def fit(self): # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - with _timer("dump_rollout_generations", timing_raw): + with simple_timer("dump_rollout_generations", timing_raw): print(batch.batch.keys()) inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) @@ -312,15 +325,21 @@ def fit(self): ) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer("testing", timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with simple_timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): - with _timer("save_checkpoint", timing_raw): + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() # training metrics diff --git a/recipe/sppo/sppo_worker.py b/recipe/sppo/sppo_worker.py index 0d52ae013..fbe3a6e48 100644 --- a/recipe/sppo/sppo_worker.py +++ b/recipe/sppo/sppo_worker.py @@ -20,10 +20,10 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage from verl.utils.flops_counter import FlopsCounter from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import ActorRolloutRefWorker logger = logging.getLogger(__file__) @@ -48,6 +48,7 @@ def init_model(self): override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) use_remove_padding = self.config.model.get("use_remove_padding", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) if self._is_actor or self._is_rollout: # we need the model for actor and rollout @@ -57,16 +58,19 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) ) # get the original unwrapped module @@ -84,10 +88,15 @@ def init_model(self): OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding - self.actor = DataParallelSPPOActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer) + self.config.actor.use_fused_kernels = use_fused_kernels + self.actor = DataParallelSPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + self.rollout, self.rollout_sharding_manager = self._build_rollout( + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: self.ref_module_fsdp = self._build_model_optimizer( @@ -96,6 +105,7 @@ def init_model(self): optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", @@ -103,6 +113,7 @@ def init_model(self): OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) if self._is_actor: @@ -112,5 +123,5 @@ def init_model(self): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents, + checkpoint_config=self.config.actor.checkpoint, ) diff --git a/requirements-npu.txt b/requirements-npu.txt index 601e8f9fa..7d0386937 100644 --- a/requirements-npu.txt +++ b/requirements-npu.txt @@ -4,17 +4,18 @@ codetiming datasets dill hydra-core -numpy +numpy<2.0.0 pandas peft pyarrow>=15.0.0 pybind11 pylatexenc -ray -tensordict<=0.6.2 -transformers>=4.52.0 +tensordict>=0.8.0,<=0.9.1,!=0.9.0 +transformers==4.52.4 +ray==2.46.0 wandb mathruler torchdata einops qwen_vl_utils +torchvision==0.20.1 diff --git a/requirements.txt b/requirements.txt index 92b71bd8d..1a1173827 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ +# requirements.txt records the full set of dependencies for development accelerate codetiming datasets dill +flash-attn hydra-core liger-kernel -numpy +numpy<2.0.0 openai pandas peft @@ -13,10 +15,14 @@ pybind11 pylatexenc pre-commit ray[default] -tensordict<=0.6.2 +tensordict>=0.8.0,<=0.9.1,!=0.9.0 torchdata transformers +# vllm==0.8.4 wandb packaging>=20.0 uvicorn fastapi +latex2sympy2_extended +math_verify +reasoning_gym diff --git a/requirements_sglang.txt b/requirements_sglang.txt index 57d5e0bef..ce9e7d536 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -5,14 +5,14 @@ datasets dill flash-attn hydra-core -numpy +numpy<2.0.0 pandas peft pyarrow>=19.0.0 pybind11 pylatexenc ray[default]>=2.10 -tensordict<=0.6.2 +tensordict>=0.8.0,<=0.9.1,!=0.9.0 torchdata torchvision transformers diff --git a/scripts/tools/clean_ckpt_training_state.sh b/scripts/tools/clean_ckpt_training_state.sh new file mode 100644 index 000000000..121acb946 --- /dev/null +++ b/scripts/tools/clean_ckpt_training_state.sh @@ -0,0 +1,2 @@ +# !/bin/bash + diff --git a/scripts/tools/converter_hf_to_mcore.py b/scripts/tools/converter_hf_to_mcore.py index 897b0f8a7..0183c1591 100644 --- a/scripts/tools/converter_hf_to_mcore.py +++ b/scripts/tools/converter_hf_to_mcore.py @@ -16,20 +16,37 @@ import argparse import os import warnings +from contextlib import contextmanager +from typing import Any, Callable, ContextManager, Optional +import numpy as np import torch +import torch.distributed as dist +from accelerate import init_empty_weights from megatron.core import dist_checkpointing from megatron.core import parallel_state as mpu +from megatron.core.dist_checkpointing.mapping import ShardedTensor from megatron.core.dist_checkpointing.serialization import StrictHandling from megatron.core.models.gpt.gpt_model import ModelType from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig +from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards from verl.models.mcore import hf_to_mcore_config +from verl.utils.device import get_device_name, get_torch_device from verl.utils.megatron_utils import get_model def _init_args(): + """ + Examples: + + 1. single rank conversion for any model: + > python converter_hf_to_mcore.py --hf_model_path %{hf_model} --output_path ${output_path} + 2. distributed conversion for DeepseekV3 671B: + > torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \ + --hf_model_path %{hf_model} --output_path ${output_path} + """ parser = argparse.ArgumentParser() parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") @@ -40,21 +57,6 @@ def _init_args(): return args -class MegatronConfig: - def __init__(self): - self.params_dtype = torch.bfloat16 - - -class ModelConfig: - def __init__(self): - self.path = None - - -class Config: - def __init__(self): - self.model = ModelConfig() - - def test_conversion(megatron_model_provider, tfconfig, output_path, model): ########### test ########### # load model @@ -74,8 +76,12 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): continue dut_data = dut_state_dict[name].data if name in ref_state_dict: - ref_data = ref_state_dict[name].data - assert dut_data.shape == ref_state_dict.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" + ref_data = ref_state_dict[name] + if isinstance(ref_data, ShardedTensor): + ref_data = ref_data.data.view(ref_data.local_shape) + else: + ref_data = ref_data.data + assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" assert (dut_data == ref_data).all(), f"{name} is not equal" print(f"{name} is equal") else: @@ -84,7 +90,11 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): if ref_state_dict[name] is None: print(f"[Warning] {name} is none in ref_state_dict") continue - ref_data = ref_state_dict[name].data + ref_data = ref_state_dict[name] + if isinstance(ref_data, ShardedTensor): + ref_data = ref_data.data.view(ref_data.local_shape) + else: + ref_data = ref_data.data if name in dut_state_dict: dut_data = dut_state_dict[name].data assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" @@ -95,7 +105,17 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): print("Conversion test passed!") -def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config): +@torch.inference_mode() +def convert_checkpoint_from_transformers_to_megatron( + hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None +): + if layer_start_end is None: + layer_start_end = (0, len(model.decoder.layers)) + layer_start, layer_end = layer_start_end + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + numel = 0 + num_attention_heads = hf_config.num_attention_heads num_key_value_heads = hf_config.num_key_value_heads hidden_dim = hf_config.hidden_size @@ -104,111 +124,273 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) print("[WARNING] Converting GQA model") has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr(hf_config, "attention_bias", False) has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None) - with torch.no_grad(): - model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight) - for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers): - layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight) - - q = hf_layer.self_attn.q_proj.weight.view([num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]) - k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1]) - v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1]) - qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous() - layer.self_attention.linear_qkv.weight.copy_(qkv) - - if has_qkv_bias: - q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1]) - k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1]) - v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1]) - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous() - layer.self_attention.linear_qkv.bias.copy_(qkv_bias) - - if hasattr(hf_layer.self_attn, "q_norm"): - layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data) - layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data) - - layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight) - layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight) - - layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight) - - for idx, hf_expert in enumerate(hf_layer.mlp.experts): - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight) - layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight) - - if has_share_expert: - layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight) - shared_fc1_weight = torch.cat([hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]) - layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight) - layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight) - - model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight) - model.output_layer.weight.copy_(hf_model.lm_head.weight) - - -@torch.no_grad() -def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig): - warnings.warn("MTP model is not supported yet", stacklevel=2) + if pp_rank == 0: + numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) - def safe_copy( - src_tensor: torch.Tensor, - dst_tensor: torch.Tensor, - skip_dtype_assert: bool = False, + assert len(model.decoder.layers) == (layer_end - layer_start), ( + f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}" + ) + for layer_idx, (layer, hf_layer) in enumerate( + zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True) ): - if not skip_dtype_assert: - if src_tensor.dtype != dst_tensor.dtype: - raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") - assert src_tensor.shape == dst_tensor.shape - dst_tensor.data.copy_(src_tensor.data) - return src_tensor.numel() - - model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight) - for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)): - print(layer_idx) - layer.input_layernorm.weight.copy_(hf_layer.input_layernorm.weight) + global_layer_idx = layer_idx + layer_start + numel_cur = numel + numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight) + + q = hf_layer.self_attn.q_proj.weight.view( + [num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1] + ) + k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1]) + v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1]) + qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous() + numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight) + + if has_qkv_bias: + q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1]) + k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1]) + v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1]) + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous() + numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias) + + if hasattr(hf_layer.self_attn, "q_norm"): + numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight) + numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight) + + numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight) + numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight) + + numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) + + for idx, hf_expert in enumerate(hf_layer.mlp.experts): + fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) + numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"]) + numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"]) + + if has_share_expert: + numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight) + shared_fc1_weight = torch.cat( + [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight] + ) + numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) + numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) + print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}") + + if pp_rank == pp_size - 1: + numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) + numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight) + return numel + + +def safe_copy( + src_tensor: torch.Tensor, + dst_tensor: torch.Tensor, + skip_dtype_assert: bool = False, +): + if not skip_dtype_assert: + if src_tensor.dtype != dst_tensor.dtype: + raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") + assert src_tensor.shape == dst_tensor.shape + dst_tensor.data.copy_(src_tensor.data) + return src_tensor.numel() + + +@torch.inference_mode() +def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config): + mgmodel = mgmodel.bfloat16() + hfmodel = hfmodel.bfloat16() + num_attention_heads = hf_config.num_attention_heads + num_query_groups = hf_config.num_key_value_heads + hidden_size = hf_config.hidden_size + head_dim = hidden_size // num_attention_heads + + # 1. vision model + hfvision = hfmodel.visual + mgvision = mgmodel.vision_model + vision_hidden_size = mgvision.config.hidden_size + vision_num_query_groups = mgvision.config.num_query_groups + vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads + copied_numel = 0 + safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq) + copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight) + for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True): + # norm1 --> linear_qkv.norm + copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight) + # norm2 --> mlp.linear_fc1.norm + copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight) + # qkv --> self_attention.linear_qkv + converted_weight = ( + hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size) + .transpose(0, 1) + .flatten(1, 2) + .reshape(-1, vision_hidden_size) + .contiguous() + ) + copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight) + converted_bias = ( + hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1) + .transpose(0, 1) + .flatten(1, 2) + .view(-1) + .contiguous() + ) + copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias) + # proj --> self_attention.linear_proj + copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight) + copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias) + # mlp --> mlp: gate + fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight]) + fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias]) + copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight) + copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias) + copied_numel += safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight) + copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias) + + # 2. vision projector + hfprojector = hfvision.merger + mgprojector = mgvision.projection + copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight) + + copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight) + copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias) + copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight) + copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias) + n_params = sum([t.numel() for t in hfvision.state_dict().values()]) + assert n_params == copied_numel + # 3. llm [just Qwen2] + hfllm = hfmodel.model + mgllm = mgmodel.language_model + copied_numel = 0 + copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight) + for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True): + copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight) + + q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous() + copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight) + + q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1) + k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1) + v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1) + qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous() + copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias) + copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight) + + fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight]) + copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight) + + copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight) + copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight) + + copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight) + if not hf_config.tie_word_embeddings: + safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight) + + n_params = sum([t.numel() for t in hfllm.state_dict().values()]) + + assert n_params == copied_numel + + +@torch.inference_mode() +def convert_checkpoint_from_transformers_to_megatron_dpskv3( + hf_model, + model, + hf_config, + tfconfig, + layer_start_end: Optional[tuple[int, int]] = None, +): + warnings.warn("MTP model is not supported yet", stacklevel=2) + if layer_start_end is None: + layer_start_end = (0, len(model.decoder.layers)) + layer_start, layer_end = layer_start_end + numel: int = 0 + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + if pp_rank == 0: + numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) + + assert len(model.decoder.layers) == (layer_end - layer_start), ( + f"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}" + ) + for layer_idx, (layer, hf_layer) in enumerate( + zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True) + ): + global_layer_idx = layer_idx + layer_start + numel_cur: int = numel + numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight) if hf_config.q_lora_rank is None: - layer.self_attention.linear_q_proj.weight.copy_(hf_layer.self_attn.q_proj.weight) + numel += safe_copy(hf_layer.self_attn.q_proj.weight, layer.self_attention.linear_q_proj.weight) else: - layer.self_attention.linear_q_down_proj.weight.copy_(hf_layer.self_attn.q_a_proj.weight) - layer.self_attention.linear_q_up_proj.weight.copy_(hf_layer.self_attn.q_b_proj.weight) - layer.self_attention.linear_q_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.q_a_layernorm.weight) - - layer.self_attention.linear_kv_down_proj.weight.copy_(hf_layer.self_attn.kv_a_proj_with_mqa.weight) - layer.self_attention.linear_kv_up_proj.weight.copy_(hf_layer.self_attn.kv_b_proj.weight) - layer.self_attention.linear_kv_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.kv_a_layernorm.weight) - layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight) + numel += safe_copy(hf_layer.self_attn.q_a_proj.weight, layer.self_attention.linear_q_down_proj.weight) + numel += safe_copy(hf_layer.self_attn.q_b_proj.weight, layer.self_attention.linear_q_up_proj.weight) + numel += safe_copy( + hf_layer.self_attn.q_a_layernorm.weight, layer.self_attention.linear_q_up_proj.layer_norm_weight + ) + + numel += safe_copy( + hf_layer.self_attn.kv_a_proj_with_mqa.weight, layer.self_attention.linear_kv_down_proj.weight + ) + numel += safe_copy(hf_layer.self_attn.kv_b_proj.weight, layer.self_attention.linear_kv_up_proj.weight) + numel += safe_copy( + hf_layer.self_attn.kv_a_layernorm.weight, layer.self_attention.linear_kv_up_proj.layer_norm_weight + ) + numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight) if not hasattr(layer.mlp, "router"): - layer.mlp.linear_fc1.layer_norm_weight.copy_(hf_layer.post_attention_layernorm.weight) - layer.mlp.linear_fc1.weight.copy_(torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight])) - layer.mlp.linear_fc2.weight.copy_(hf_layer.mlp.down_proj.weight) + numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight) + numel += safe_copy( + torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), layer.mlp.linear_fc1.weight + ) + numel += safe_copy(hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight) else: - layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight) + numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \ # recover to fp32 in the first forward. There is always a diff in the bias between two models (~0.3%) - safe_copy(hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True) + numel += safe_copy( + hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True + ) if tfconfig.moe_grouped_gemm: for i, hf_expert in enumerate(hf_layer.mlp.experts): fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i)) - linear_fc1_weighti.copy_(fc1_weight) + numel += safe_copy(fc1_weight, linear_fc1_weighti) linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i)) - linear_fc2_weighti.copy_(hf_expert.down_proj.weight) + numel += safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti) else: for i, hf_expert in enumerate(hf_layer.mlp.experts): expert = layer.mlp.experts.local_experts[i] fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - expert.linear_fc1.weight.copy_(fc1_weight) - expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight) - layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight) - shared_fc1_weight = torch.cat([hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight]) - layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight) - layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_experts.down_proj.weight) - - model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight) + numel += safe_copy(fc1_weight, expert.linear_fc1.weight) + numel += safe_copy(hf_expert.down_proj.weight, expert.linear_fc2.weight) + numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight) + shared_fc1_weight = torch.cat( + [hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight] + ) + numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) + numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) + print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}") + assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), "numel mismatch" + + if pp_rank == pp_size - 1: + numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) if not hf_config.tie_word_embeddings: - model.output_layer.weight.copy_(hf_model.lm_head.weight) + numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight) + print(f"{pp_rank=} {numel=}") + return numel + + +@contextmanager +def noop_context() -> Any: + yield + + +def support_distributed_convert(hf_config: AutoConfig) -> bool: + for arch in ["DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM", "Qwen2MoeForCausalLM"]: + if arch in hf_config.architectures: + return True + return False def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False): @@ -218,13 +400,22 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False return # init torch distributed and mpu - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" + if "WORLD_SIZE" not in os.environ: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group("nccl") + + rank = dist.get_rank() + local_rank = os.getenv("LOCAL_RANK", 0) + world_size = dist.get_world_size() + get_torch_device().set_device(f"{get_device_name()}:{local_rank}") + mpu.initialize_model_parallel( tensor_model_parallel_size=1, + pipeline_model_parallel_size=world_size, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, expert_model_parallel_size=1, @@ -235,9 +426,18 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False hf_config = AutoConfig.from_pretrained(hf_model_path) print(hf_config, flush=True) - cfg = Config() - cfg.model.path = hf_model_path - tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16) + if world_size > 1 and not support_distributed_convert(hf_config): + raise NotImplementedError(f"distributed conversion is not supported for {hf_config.architectures} yet.") + + pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size) + print(f"Pipeline shards: {pipeline_shards}", flush=True) + + tfconfig = hf_to_mcore_config( + hf_config, + torch.bfloat16, + num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None, + num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None, + ) tfconfig.use_cpu_initialization = use_cpu_initialization tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False) @@ -255,23 +455,62 @@ def megatron_model_provider(pre_process, post_process): ) return parallel_model - model = get_model( - model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - transformer_config=tfconfig, - ) + context: Callable[..., ContextManager] = init_empty_weights if use_cpu_initialization else noop_context + with context(): + model = get_model( + model_provider_func=megatron_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + transformer_config=tfconfig, + ) + + if use_cpu_initialization: + # convert meta device to empty tensor so it can use `copy_` function + model[0].module = model[0].module.to_empty(device="cpu") with warnings.catch_warnings(): warnings.simplefilter("ignore") + from transformers import AutoModelForCausalLM, AutoModelForImageTextToText # init hf model - hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code) + if "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: + hf_model = AutoModelForImageTextToText.from_pretrained( + hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code + ) + else: + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code + ) hf_state_dict = hf_model.state_dict() + # distributed convert + if world_size > 1 and support_distributed_convert(hf_config): + pipeline_cumsum = np.cumsum(pipeline_shards) + layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1] + layer_end = pipeline_cumsum[rank] + if "DeepseekV3ForCausalLM" in hf_config.architectures: + numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3( + hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end) + ) + elif "Qwen3MoeForCausalLM" in hf_config.architectures or "Qwen2MoeForCausalLM" in hf_config.architectures: + numel_partial: int = convert_checkpoint_from_transformers_to_megatron( + hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end) + ) + else: + raise NotImplementedError(f"Distributed conversion is not supported for {hf_config.architectures} yet.") + + numel_tensor = torch.tensor([numel_partial]).to(get_device_name()) + dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM) + numel = int(numel_tensor.cpu().item()) + print(f"total numel={numel} vs {hf_model.num_parameters()=}") + if numel != hf_model.num_parameters(): + warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1) + # load hf state dict to megatron model - if "Qwen2MoeForCausalLM" in hf_config.architectures: + elif "Qwen2MoeForCausalLM" in hf_config.architectures: convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) + elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: + convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config) elif "DeepseekV3ForCausalLM" in hf_config.architectures: convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig) elif "Qwen3MoeForCausalLM" in hf_config.architectures: @@ -300,4 +539,6 @@ def megatron_model_provider(pre_process, post_process): if __name__ == "__main__": args = _init_args() - convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code) + convert_hf_to_mcore( + args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code + ) diff --git a/scripts/tools/init_random_model.py b/scripts/tools/init_random_model.py index c626fc55b..2804bc2a2 100644 --- a/scripts/tools/init_random_model.py +++ b/scripts/tools/init_random_model.py @@ -14,7 +14,8 @@ # limitations under the License. """ -This script override a model with custom config and random weights, mainly for create small models for debugging purposes. +This script override a model with custom config and random weights, mainly for create small models for +debugging purposes. Usage: python scripts/init_random_model.py \ @@ -28,7 +29,7 @@ import json import os import warnings -from typing import Any, Dict +from typing import Any from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig @@ -51,17 +52,21 @@ def check_output_path(output_path: str): print(f"Output path '{output_path}' created.") -def check_configs(original_config: Dict[str, Any], new_config: Dict[str, Any]) -> bool: +def check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) -> bool: """ Check if the original config and new config are compatible. This is a placeholder function; actual implementation may vary based on requirements. """ # Example check: ensure 'model_type' is the same - if new_config.get("model_type", None) is not None and original_config.get("model_type") != new_config.get("model_type"): + if new_config.get("model_type", None) is not None and original_config.get("model_type") != new_config.get( + "model_type" + ): raise RuntimeError("Model types do not match.") for key in new_config: if key not in original_config: - warnings.warn(f"Key '{key}' in new config does not exist in original config, may not take effect.", stacklevel=2) + warnings.warn( + f"Key '{key}' in new config does not exist in original config, may not take effect.", stacklevel=2 + ) def init_random_model(hf_model_path, new_config_path, output_path): @@ -85,4 +90,6 @@ def init_random_model(hf_model_path, new_config_path, output_path): if __name__ == "__main__": args = _init_args() check_output_path(args.output_path) - init_random_model(hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path) + init_random_model( + hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path + ) diff --git a/scripts/tools/install_vllm_sglang_mcore.sh b/scripts/tools/install_vllm_sglang_mcore.sh old mode 100644 new mode 100755 diff --git a/scripts/tools/legacy_model_merger.py b/scripts/tools/legacy_model_merger.py new file mode 100644 index 000000000..7049fc65d --- /dev/null +++ b/scripts/tools/legacy_model_merger.py @@ -0,0 +1,781 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. + +To merge FSDP checkpoints: +```sh +python scripts/legacy_model_merger.py merge \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +To merge Megatron checkpoints: +```sh +python scripts/legacy_model_merger.py merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +For more details, please refer to documentation: +https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model +""" + +import argparse +import os +import re +import warnings +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file +from torch.distributed._tensor import Placement, Shard +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + GenerationConfig, + PretrainedConfig, +) + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from tqdm import tqdm + +from verl.utils import hf_processor, hf_tokenizer + + +@dataclass +class ModelMergerConfig: + operation: str # 'merge' or 'test' + backend: str + local_dir: str + hf_model_config_path: str + target_dir: Optional[str] = "tmp" + hf_upload_path: Optional[str] = None + private: bool = False + test_hf_dir: Optional[str] = None + tie_word_embedding: bool = False + is_value_model: bool = False + hf_model_path: Optional[str] = None + hf_upload: bool = field(init=False) + + def __post_init__(self): + self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) + if self.operation == "test": + self.target_dir = None + self.hf_upload_path = None + self.private = False + + +class BaseModelMerger(ABC): + def __init__(self, config: ModelMergerConfig): + self.config = config + self.hf_model_config_path = config.hf_model_config_path + + if config.hf_model_path: + print( + "Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. " + ) + self.hf_model_config_path = config.hf_model_path + + self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) + + def get_transformers_auto_model_class(self): + if "ForTokenClassification" in self.model_config.architectures[0]: + return AutoModelForTokenClassification + elif "ForCausalLM" in self.model_config.architectures[0]: + return AutoModelForCausalLM + elif "ForConditionalGeneration" in self.model_config.architectures[0]: + return AutoModelForVision2Seq + + raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + + def patch_model_generation_config(self, model): + """ + The generation_config created from model config may be different to the pretrained model, + this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 + + This function patch the generation_config created from model config to the pretrained model. + """ + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + except OSError: + print( + f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config." + ) + return model + + def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): + """ + Save lora adapter to safetensors. + + Returns: + lora_path: str, the path to the lora adapter. None if no lora adapter found. + + Note: + This function change the 'state_dict' in place. + """ + lora_params_names = [name for name in state_dict.keys() if "lora_" in name] + + if len(lora_params_names) == 0: + return None + + import json + from typing import OrderedDict + + import peft + from safetensors.torch import save_file + + lora_params = OrderedDict() + target_modules = set() + lora_key = None + + for name in lora_params_names: + lora_key = name.replace(".default.weight", ".weight") + target_modules.add(lora_key.split(".")[-3]) + lora_params[lora_key] = state_dict.pop(name) + + lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) + peft_dict = { + "r": lora_rank, + "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. + "target_modules": list(target_modules), + } + peft_config = peft.LoraConfig(**peft_dict).to_dict() + peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None + peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["target_modules"] = list(peft_config["target_modules"]) + + lora_path = os.path.join(self.config.target_dir, "lora_adapter") + os.makedirs(lora_path, exist_ok=True) + with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) + + for name in list(state_dict.keys()): + key = ( + name.replace("base_model.model.", "") + .replace(".base_layer.weight", ".weight") + .replace(".base_layer.bias", ".bias") + ) + state_dict[key] = state_dict.pop(name) + + return lora_path + + def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + with init_empty_weights(): + model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) + model.to_empty(device="cpu") + model = self.patch_model_generation_config(model) + + lora_path = self.save_lora_adapter(state_dict) + if lora_path: + print(f"Saving lora adapter to {lora_path}") + + print(f"Saving model to {self.config.target_dir}") + model.save_pretrained(self.config.target_dir, state_dict=state_dict) + del state_dict + del model + + processor = hf_processor(self.hf_model_config_path) + tokenizer = hf_tokenizer(self.hf_model_config_path) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def upload_to_huggingface(self): + from huggingface_hub import HfApi + + api = HfApi() + api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + + @abstractmethod + def merge_and_save(self): + raise NotImplementedError("Subclasses should implement this method") + + +class FSDPModelMerger(BaseModelMerger): + def _get_world_size(self) -> int: + """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" + for filename in os.listdir(self.config.local_dir): + match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) + if match: + return int(match.group(1)) + raise FileNotFoundError( + f"Could not determine world size. No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}" + ) + + def _load_rank_zero_state_dict(self, world_size: int) -> dict: + return torch.load( + Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", + map_location="cpu", + weights_only=False, + ) + + def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + """ + Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. + If no DTensor is found, infers a simple FSDP mesh based on world_size. + """ + pivot_key = sorted(list(state_dict.keys()))[0] + weight = state_dict[pivot_key] + + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([world_size], dtype=np.int64) + mesh_dim_names = ("fsdp",) + + return mesh, mesh_dim_names + + def _calculate_shard_configuration( + self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] + ) -> tuple[int, tuple[int, ...]]: + """Calculates the total number of shards and the shape of the device mesh.""" + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + + if "tp" in mesh_dim_names: + # TODO: "tp" is not supported yet due to the above assert + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + return total_shards, mesh_shape + + def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + """Merges a list of tensors based on their DTensor placement""" + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + + raise NotImplementedError(f"Unsupported placement: {placement}") + + def _load_and_merge_state_dicts( + self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] + ) -> dict[str, torch.Tensor]: + model_state_dict_lst = [None] * total_shards + + def process_one_shard(rank: int, model_state_dict_lst: list): + model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] + for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + future.result() + + # Merge state dicts from all shards + state_dict = {} + param_placements: dict[str, list] = {} + + for key in set(model_state_dict_lst[0].keys()): + state_dict[key] = [] + for model_state_shard in model_state_dict_lst: + # add tensor shard in order of rank to state_dict[key] + tensor = model_state_shard.pop(key) + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + + placements = tuple(tensor.placements) + # replicated placement at dp dimension can be discarded + if mesh_dim_names[0] in ("dp", "ddp"): + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) + + del model_state_dict_lst + + # Merge tensors + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + if key in param_placements: + # merge shards + placements: tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: + # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = self._merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) + + return state_dict + + def merge_and_save(self): + world_size = self._get_world_size() + rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) + + mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + + total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + print(f"Processing model shards with {total_shards} {mesh_shape} in total") + + merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._test_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + + hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_state_dict = hf_model.state_dict() + del hf_model + + hf_model_keys = set(hf_state_dict.keys()) + collected_keys = set(state_dict.keys()) + + missing_keys = hf_model_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + + extra_keys = collected_keys - hf_model_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + + for key in hf_model_keys: + hf_shape = hf_state_dict[key].shape + collected_shape = state_dict[key].shape + assert hf_shape == collected_shape, ( + f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" + ) + + hf_dtype = hf_state_dict[key].dtype + collected_dtype = state_dict[key].dtype + assert hf_dtype == collected_dtype, ( + f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" + ) + + torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + + +class MegatronModelMerger(BaseModelMerger): + def __init__(self, config: ModelMergerConfig): + from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path + + config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) + super().__init__(config) + + self.params_mapping = { + # megatron core gpt model name, huggingface model name + # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the longer key within the containing relationship is processed first. + "embedding.word_embeddings": "model.embed_tokens", + # attn + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", + "self_attention.linear_qkv": "self_attn.qkv_proj", + "self_attention.q_layernorm": "self_attn.q_norm", + "self_attention.k_layernorm": "self_attn.k_norm", + "self_attention.linear_proj": "self_attn.o_proj", + # mla + "self_attention.linear_q_proj": "self_attn.q_proj", + "self_attention.linear_q_down_proj": "self_attn.q_a_proj", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + "self_attention.linear_q_up_proj": "self_attn.q_b_proj", + "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", + # mlp + "pre_mlp_layernorm": "post_attention_layernorm", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", + "mlp.linear_fc1": "mlp.gate_up_proj", + "mlp.linear_fc2": "mlp.down_proj", + # moe + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + "mlp.router": "mlp.gate", + "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", + "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", + "linear_fc1": "gate_up_proj", + "linear_fc2": "down_proj", + # output + "final_layernorm": "norm", + "output_layer": "lm_head", + } + + def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: + tp_rank = pp_rank = None + rank_list = sharded_dir.split("_")[2:] + if re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir): + tp_rank = int(rank_list[0]) + pp_rank = int(rank_list[1]) + elif re.match(r"mp_rank_(\d\d)", sharded_dir): + tp_rank = int(rank_list[0]) + pp_rank = 0 + + assert tp_rank is not None and pp_rank is not None, f"Invalid sharded dir {sharded_dir}" + + return tp_rank, pp_rank + + def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: + """ + Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). + Determines TP and PP sizes from directory names. + """ + tp_size = 0 + pp_size = 0 + sharded_dirs = sorted(os.listdir(model_path)) + for sharded_dir in sharded_dirs: + assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" + tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) + tp_size = max(tp_size, tp_rank + 1) + pp_size = max(pp_size, pp_rank + 1) + return sharded_dirs, tp_size, pp_size + + def _merge_across_tp( + self, + key: str, + tp_data: list[torch.Tensor], + config: PretrainedConfig, + tp_size: int, + is_value_model: bool = False, + ) -> torch.Tensor | list[torch.Tensor]: + if "linear_fc1.weight" in key: + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + for infer_param in tp_data: + gate, up = infer_param.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + return [gate, up] + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: + # if the tensor is qkv, for each param on tp, split into q, k, v + # concat q, k, v separately. + q_lst = [] + k_lst = [] + v_lst = [] + assert config.num_attention_heads % config.num_key_value_heads == 0 + num_q_per_kv = config.num_attention_heads // config.num_key_value_heads + assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 + kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) + split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] + + for infer_param in tp_data: + num_query_groups_per_partition = config.num_key_value_heads // tp_size + for chunk in infer_param.chunk(num_query_groups_per_partition): + split_size = [ + kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + ] + q, k, v = chunk.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) + + q = torch.cat(q_lst, dim=0) + k = torch.cat(k_lst, dim=0) + v = torch.cat(v_lst, dim=0) + return [q, k, v] + elif "layer_norm" in key or "layernorm" in key or "router" in key or ("output_layer" in key and is_value_model): + return tp_data[0] + else: + dim = 0 + if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: + dim = 1 + return torch.cat(tp_data, dim=dim) + + def _load_state_dicts( + self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int + ) -> list[list[dict]]: + model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)] + + def _process_one_megatron_shard(sharded_dir: str): + model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" + state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) + tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) + model_state_dict_lst[pp_rank][tp_rank] = state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] + for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): + future.result() + + return model_state_dict_lst + + def _check_megatron_state_key(self, key: str) -> bool: + """ + Checks if the key is a valid Megatron state key. + + Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. + Shall not use key starts with "model." + """ + if key.startswith("model."): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer." + ) + + skip_checking_keys = ["embedding.word_embeddings", "output_layer"] + for skip_key in skip_checking_keys: + if skip_key in key: + print(f"skip checking key {key}") + return + + # Exclude extra state keys + if not key.startswith("decoder"): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." + ) + + def _merge_state_dicts( + self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int + ) -> dict[str, torch.Tensor]: + state_dict = {} + vpp_size = len(model_state_dict_lst[0][0]) + layers_cum = 0 + + for vpp_rank in range(vpp_size): + for pp_rank in range(pp_size): + layers_handled = 0 + keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() + for key in keys: + if "extra_state" in key: + continue + if self.config.tie_word_embedding and ("output_layer" in key): + print("skip lm_head and reward_head loading because of tie_word_embeddings") + continue + + self._check_megatron_state_key(key) + hf_name = self._replace_name(key, self.params_mapping) + assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + if "model.layers." in hf_name: + local_layer_no = int(hf_name.split(".")[2]) + layers_handled = max(local_layer_no, layers_handled) + global_layer_no = local_layer_no + layers_cum + new_key_list = hf_name.split(".") + new_key_list[2] = str(global_layer_no) + hf_name = ".".join(new_key_list) + else: + warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) + + tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] + merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model) + + if not isinstance(merged, list): + state_dict[hf_name] = merged + elif len(merged) == 3: + # split qkv + for n, d in zip(["q", "k", "v"], merged, strict=False): + state_dict[hf_name.replace("qkv", n)] = d + elif len(merged) == 2: + # split gate up + state_dict[hf_name.replace("gate_up", "gate")] = merged[0] + state_dict[hf_name.replace("gate_up", "up")] = merged[1] + print( + f"converted {key} to {hf_name} with shape {merged.shape if isinstance(merged, torch.Tensor) else [t.shape for t in merged]}" + ) + + layers_cum += layers_handled + 1 # zero based + + return state_dict + + def merge_and_save(self): + from verl.utils.megatron_utils import get_model_checkpoint_path + + model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) + sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) + print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") + + model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) + merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) + del model_state_dict_lst + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._test_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): + """ + Compares the merged Megatron state_dict against a reference safetensors model. + Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. + """ + ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") + + for name, loaded_weight in state_dict.items(): + # name = self._replace_name(original_name, self.params_mapping) + if not name or name.endswith(".bias") and name not in ref_state_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embedding and "lm_head.weight" in name: + continue + if name not in ref_state_dict: + raise RuntimeError(f"key: {name} not exist in state_dict") + param = ref_state_dict[name] + assert loaded_weight.dtype == param.dtype + torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2) + + def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: + for m_name, v_name in name_mapping.items(): + if m_name not in megatron_name: + continue + + megatron_name = megatron_name.replace("decoder", "model") + param_name = megatron_name.replace(m_name, v_name) + return param_name + + return None # Return None if no mapping found + + +def main(): + parser = argparse.ArgumentParser(description="verl model merger") + subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + + base_op_parser = argparse.ArgumentParser(add_help=False) + base_op_parser.add_argument( + "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + ) + base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") + base_op_parser.add_argument( + "--hf_model_path", + type=str, + default=None, + help="(Deprecated) Path to the original Hugging Face model for config.", + ) + base_op_parser.add_argument( + "--tie-word-embedding", + action="store_true", + help="Whether to tie word embedding weights (currently only Megatron supported)", + ) + base_op_parser.add_argument( + "--is-value-model", + action="store_true", + help="Whether the model is a value model (currently only Megatron supported)", + ) + + merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser.add_argument( + "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + ) + merge_parser.add_argument( + "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + ) + merge_parser.add_argument( + "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + ) + + test_parser = subparsers.add_parser( + "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + ) + test_parser.add_argument( + "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + ) + + args = parser.parse_args() + + common_config_args = { + "operation": args.operation, + "backend": args.backend, + "tie_word_embedding": args.tie_word_embedding, + "is_value_model": args.is_value_model, + "local_dir": args.local_dir, + "hf_model_path": args.hf_model_path, + "hf_model_config_path": args.local_dir, + } + + if args.operation == "merge": + config = ModelMergerConfig( + **common_config_args, + target_dir=args.target_dir, + hf_upload_path=args.hf_upload_path, + private=args.private, + test_hf_dir=None, + ) + os.makedirs(config.target_dir, exist_ok=True) + elif args.operation == "test": + config = ModelMergerConfig( + **common_config_args, + test_hf_dir=args.test_hf_dir, + # the following args are not used by test operation + target_dir=None, + hf_upload_path=None, + private=False, + ) + else: + raise NotImplementedError(f"Unknown operation: {args.operation}") + + if config.backend == "fsdp": + merger = FSDPModelMerger(config) + elif config.backend == "megatron": + merger = MegatronModelMerger(config) + else: + raise NotImplementedError(f"Unknown backend: {config.backend}") + + merger.merge_and_save() + + +if __name__ == "__main__": + main() diff --git a/scripts/tools/pass_rate.py b/scripts/tools/pass_rate.py new file mode 100644 index 000000000..a2f79981b --- /dev/null +++ b/scripts/tools/pass_rate.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +""" +Compute rewards for all parquet files in a directory. +Optionally compute response lengths using a specified tokenizer (default: DeepSeek-R1-0528). +""" + +import argparse +import json +import os +import time +from datetime import datetime, timedelta +from multiprocessing import cpu_count +from multiprocessing.pool import Pool +from pathlib import Path +from typing import Optional, Dict, List, Tuple + +from tqdm import tqdm +import pandas as pd +import numpy as np +from datasets import Dataset + +from verl.utils.reward_score import default_compute_score + +# Global tokenizer variable for response length computation +_tokenizer = None + +def init_tokenizer(model_name: str = "deepseek-ai/DeepSeek-R1-0528"): + """Initialize the tokenizer for response length computation.""" + global _tokenizer + try: + from transformers import AutoTokenizer + _tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + print(f"Successfully initialized {model_name} tokenizer for length computation") + except Exception as e: + print(f"Warning: Failed to initialize tokenizer '{model_name}': {e}") + _tokenizer = None + +def compute_response_length(response: str) -> int: + """Compute the token length of a response using the initialized tokenizer.""" + global _tokenizer + if _tokenizer is None: + return 0 + try: + return len(_tokenizer.encode(response)) + except Exception as e: + print(f"Warning: Failed to compute response length: {e}") + return 0 + +def compute_length_stats(lengths: List[int]) -> Dict[str, float]: + """Compute min, max, and average length statistics.""" + if not lengths: + return {"min": 0, "max": 0, "avg": 0} + return { + "min": min(lengths), + "max": max(lengths), + "avg": sum(lengths) / len(lengths) + } + +def format_time(seconds: float) -> str: + return str(timedelta(seconds=int(seconds))) + +def sanitize_for_parquet(obj): + """Recursively sanitize an object to make it parquet-compatible.""" + if obj is None: + return None + elif isinstance(obj, bool): + return int(obj) + elif isinstance(obj, np.bool_): + return int(obj) + elif isinstance(obj, (np.integer, np.int64, np.int32, np.int16, np.int8)): + return int(obj) + elif isinstance(obj, (np.floating, np.float64, np.float32, np.float16)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {key: sanitize_for_parquet(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + return [sanitize_for_parquet(item) for item in obj] + elif isinstance(obj, (str, int, float)): + return obj + else: + return str(obj) + +def compute_single_reward(arg_tuple): + """Compute reward for one (response, ground-truth) pair.""" + gid, response, data_source, ground_truth, extra_info, resp_idx, compute_length, raw_resp = arg_tuple + # print(f"Computing reward for {gid} with data_source {data_source}") + # print(f"Response: {response}") + # print(f"Ground truth: {ground_truth}") + + try: + result = default_compute_score( + data_source=data_source, + solution_str=response, + ground_truth=ground_truth, + extra_info=extra_info, + ) + + if isinstance(result, dict): + detailed = result + else: + detailed = {"score": float(result)} + + # Add response length if requested + if compute_length: + detailed["response_length"] = compute_response_length(raw_resp) + + except ValueError as e: + if "embedded null byte" in str(e): + detailed = {"score": 0.0, "error": "null_byte_in_response"} + if compute_length: + detailed["response_length"] = 0 + print(f"Warning: Null byte detected in response {gid}, marking as failed") + else: + # Re-raise other ValueError types + raise + except Exception as e: + # Handle any other unexpected errors gracefully + detailed = {"score": 0.0, "error": f"execution_error: {str(e)[:100]}"} + if compute_length: + detailed["response_length"] = 0 + print(f"Warning: Error computing reward for response {gid}: {str(e)[:100]}") + + detailed = sanitize_for_parquet(detailed) + return gid, detailed, resp_idx + +def read_parquet_file(file_path: str) -> pd.DataFrame: + """Read a parquet file using HuggingFace datasets.""" + if not os.path.exists(file_path) or os.path.getsize(file_path) == 0: + return pd.DataFrame() + + dataset = Dataset.from_parquet(file_path) + return dataset.to_pandas() + +def write_parquet_file(file_path: str, data: pd.DataFrame): + """Write DataFrame to a parquet file using HuggingFace datasets.""" + dataset = Dataset.from_pandas(data) + dataset.to_parquet(file_path) + +def process_parquet_file(file_path: str, output_path: str, args, reward_pool: Optional[Pool]) -> Tuple[int, float]: + """Process a single parquet file and compute rewards.""" + print(f"Processing: {os.path.basename(file_path)}") + + df = read_parquet_file(file_path) + if df.empty: + print(f"No valid data found in {file_path}") + return 0, 0.0 + + start_time = time.time() + tasks, lookup, gid = [], {}, 0 + + # Track skipped items + skipped = {"existing_scores": 0, "no_responses": 0, "no_ground_truth": 0} + processed_rows = 0 + + for i in range(len(df)): + # Skip if already has scores and not recalculating + current_scores = df.iloc[i].get("scores") + if not args.recalculate_rewards and current_scores is not None: + if isinstance(current_scores, (list, tuple)) and len(current_scores) > 0: + skipped["existing_scores"] += 1 + continue + elif isinstance(current_scores, np.ndarray) and current_scores.size > 0: + skipped["existing_scores"] += 1 + continue + + responses = df.iloc[i][args.response_column_name] + reward_model_data = df.iloc[i]["reward_model"] + + # Validate responses + if responses is None or (isinstance(responses, float) and pd.isna(responses)): + skipped["no_responses"] += 1 + continue + + if isinstance(responses, np.ndarray): + responses = responses.tolist() + + if not isinstance(responses, list): + skipped["no_responses"] += 1 + continue + + # Extract ground truth + if not isinstance(reward_model_data, dict): + skipped["no_ground_truth"] += 1 + continue + + ground_truth = reward_model_data.get("ground_truth", "") + + # Handle different ground_truth formats + if ground_truth is None: + skipped["no_ground_truth"] += 1 + continue + elif isinstance(ground_truth, str) and not ground_truth: + skipped["no_ground_truth"] += 1 + continue + elif isinstance(ground_truth, list) or (isinstance(ground_truth, np.ndarray) and ground_truth.ndim > 0): + if len(ground_truth) == 0: + skipped["no_ground_truth"] += 1 + continue + # Convert ndarray to list while keeping original shape + if isinstance(ground_truth, np.ndarray): + ground_truth = ground_truth.tolist() + elif isinstance(ground_truth, np.ndarray) and ground_truth.ndim == 0: + # Handle numpy scalars (0-dimensional arrays) + ground_truth = ground_truth.item() + # For other types (int, float, etc.), use as-is + + # Get data source and extra info + data_source = df.iloc[i].get("data_source", df.iloc[i].get("source", "unknown")) + extra_info = df.iloc[i].get("extra_info", {}) + + processed_rows += 1 + + # Create tasks for each response + for resp_idx, raw_resp in enumerate(responses): + if raw_resp is None: + continue + + # Extract response, removing thinking tags + stripped = raw_resp.split("
", 1)[1] if "
" in str(raw_resp) else str(raw_resp) + tasks.append((gid, stripped, data_source, ground_truth, extra_info, resp_idx, args.compute_response_length, raw_resp)) + lookup[gid] = i + gid += 1 + + if not tasks: + print(f"No tasks to process in {os.path.basename(file_path)}") + return 0, 0.0 + + print(f"Computing rewards for {len(tasks)} responses across {processed_rows} items...") + + # Initialize results + detailed_by_sample: Dict[int, List[Optional[Dict]]] = {} + for i in range(len(df)): + responses = df.iloc[i][args.response_column_name] + if isinstance(responses, (list, np.ndarray)): + detailed_by_sample[i] = [None] * len(responses) + + # Process tasks + if reward_pool: + results = reward_pool.map(compute_single_reward, tasks) + else: + results = [compute_single_reward(task) for task in tqdm(tasks, desc="Computing rewards")] + + # Collect results + for gidx, detailed, resp_idx in results: + row_idx = lookup[gidx] + if row_idx in detailed_by_sample: + detailed_by_sample[row_idx][resp_idx] = detailed + + # Update DataFrame + detailed_scores_list = [None] * len(df) + scores_list = [None] * len(df) + pass_rate_list = [None] * len(df) + response_lengths_list = [None] * len(df) if args.compute_response_length else None + + total_responses = 0 + total_passed = 0 + question_pass_rates = [] + + # Length tracking + all_response_lengths = [] + passed_response_lengths = [] + + for row_idx, detailed_list in detailed_by_sample.items(): + # Fill missing results + for i, d in enumerate(detailed_list): + if d is None: + missing_result = {"score": 0.0, "error": "missing"} + if args.compute_response_length: + missing_result["response_length"] = 0 + detailed_list[i] = missing_result + + scores = [d["score"] for d in detailed_list] + pass_cnt = sum(s >= args.correct_reward_threshold for s in scores) + question_pass_rate = pass_cnt / len(scores) if len(scores) > 0 else 0.0 + + detailed_scores_list[row_idx] = detailed_list + scores_list[row_idx] = scores + pass_rate_list[row_idx] = question_pass_rate + + # Extract response lengths if computed + if args.compute_response_length: + response_lengths = [d.get("response_length", 0) for d in detailed_list] + response_lengths_list[row_idx] = response_lengths + + # Collect length statistics + all_response_lengths.extend(response_lengths) + for i, length in enumerate(response_lengths): + if scores[i] >= args.correct_reward_threshold: + passed_response_lengths.append(length) + + question_pass_rates.append(question_pass_rate) + total_passed += pass_cnt + total_responses += len(scores) + + # Add results to DataFrame + df["detailed_scores"] = detailed_scores_list + df["scores"] = scores_list + df["pass_rate"] = pass_rate_list + + # Add response lengths if computed + if args.compute_response_length: + df["response_lengths"] = response_lengths_list + + # Compute and add length statistics + all_length_stats = compute_length_stats(all_response_lengths) + passed_length_stats = compute_length_stats(passed_response_lengths) + + length_stats = { + "all_min": all_length_stats["min"], + "all_max": all_length_stats["max"], + "all_avg": all_length_stats["avg"], + "passed_min": passed_length_stats["min"], + "passed_max": passed_length_stats["max"], + "passed_avg": passed_length_stats["avg"] + } + df["length_statistics"] = [length_stats for _ in range(len(df))] + + # Add model pass rate + model_pass_rate = sum(question_pass_rates) / len(question_pass_rates) if len(question_pass_rates) > 0 else 0.0 + df["model_pass_rate"] = [{args.model_name: model_pass_rate} for _ in range(len(df))] + + # Save results + write_parquet_file(output_path, df) + + elapsed = time.time() - start_time + + # Print summary + print(f"Results: {len(df)} items, {processed_rows} processed, {sum(skipped.values())} skipped") + print(f"Pass rate: {model_pass_rate:.2%} ({total_passed}/{total_responses})") + + # Print response length statistics if computed + if args.compute_response_length and all_response_lengths: + all_stats = compute_length_stats(all_response_lengths) + passed_stats = compute_length_stats(passed_response_lengths) if passed_response_lengths else {"min": 0, "max": 0, "avg": 0} + print(f"Length stats - All: {all_stats['min']}/{all_stats['max']}/{all_stats['avg']:.1f}, Passed: {passed_stats['min']}/{passed_stats['max']}/{passed_stats['avg']:.1f} tokens") + + print(f"Time: {format_time(elapsed)}") + + return total_responses, elapsed + +def main(): + parser = argparse.ArgumentParser(description="Compute rewards for parquet files") + + parser.add_argument("input_path", help="Directory or parquet file path") + parser.add_argument("--output_dir", help="Output directory") + parser.add_argument("--model_name", default="r1-0528", help="Model name") + parser.add_argument("--output_suffix", default="", help="Output filename suffix") + parser.add_argument("--pattern", default="*.parquet", help="File pattern for directories") + parser.add_argument("--recursive", "-r", action="store_true", help="Recursive search") + + parser.add_argument("--reward_workers", type=int, default=64, help="Worker processes") + parser.add_argument("--correct_reward_threshold", type=float, default=1.0, help="Pass threshold") + parser.add_argument("--recalculate_rewards", action="store_true", help="Recompute existing rewards") + parser.add_argument("--maxtasks_per_child", type=int, default=50, help="Tasks per worker") + + parser.add_argument("--compute_response_length", action="store_true", help="Compute response length using specified tokenizer") + parser.add_argument("--tokenizer_model", default="deepseek-ai/DeepSeek-R1-0528", help="Tokenizer model name for response length computation (default: deepseek-ai/DeepSeek-R1-0528)") + parser.add_argument("--response_column_name", default="r1_0528_responses", help="Column name for responses in the parquet file (default: r1_0528_responses)") + parser.add_argument("--debug", action="store_true", help="Process first file only") + + args = parser.parse_args() + + # Initialize tokenizer if response length computation is requested + if args.compute_response_length: + init_tokenizer(args.tokenizer_model) + + # Validate input + input_path = Path(args.input_path) + if not input_path.exists(): + print(f"Input path does not exist: {args.input_path}") + return + + is_single_file = input_path.is_file() + + # Setup output + if args.output_dir: + output_base = Path(args.output_dir) + if is_single_file: + stem = input_path.stem + suffix = input_path.suffix + output_filename = f"{stem}_{args.output_suffix}{suffix}" if args.output_suffix else f"{stem}_scored{suffix}" + output_path = output_base / output_filename + else: + output_path = output_base + else: + if is_single_file: + stem = input_path.stem + suffix = input_path.suffix + output_filename = f"{stem}_{args.output_suffix}{suffix}" if args.output_suffix else f"{stem}_scored{suffix}" + output_path = input_path.parent / output_filename + else: + output_path = input_path.parent / f"{input_path.name}_scored" + + # Find files + if is_single_file: + parquet_files = [input_path] + else: + if args.recursive: + parquet_files = list(input_path.rglob(args.pattern)) + else: + parquet_files = list(input_path.glob(args.pattern)) + + if not parquet_files: + print(f"No files matching '{args.pattern}' found") + return + + parquet_files.sort() + + if args.debug: + parquet_files = parquet_files[:1] + + # Setup workers + workers = min(args.reward_workers, max(1, cpu_count() - 1)) if args.reward_workers > 1 else 1 + reward_pool = Pool(processes=workers, maxtasksperchild=args.maxtasks_per_child) if workers > 1 else None + + print(f"Processing {len(parquet_files)} files with {workers} workers") + print(f"Input: {args.input_path}") + print(f"Output: {output_path}") + + # Create output directory + if is_single_file: + output_path.parent.mkdir(parents=True, exist_ok=True) + else: + output_path.mkdir(parents=True, exist_ok=True) + + # Process files + total_processed = 0 + total_elapsed = 0.0 + successful_files = 0 + + for file_path in parquet_files: + if is_single_file: + output_file_path = output_path + else: + relative_path = file_path.relative_to(input_path) + if args.output_suffix: + stem = relative_path.stem + suffix = relative_path.suffix + output_filename = f"{stem}_{args.output_suffix}{suffix}" + output_file_path = output_path / relative_path.parent / output_filename + else: + output_file_path = output_path / relative_path + output_file_path.parent.mkdir(parents=True, exist_ok=True) + + # Skip if output file already exists when processing a folder + if output_file_path.exists(): + print(f"Skipping {os.path.basename(file_path)} - output file already exists: {output_file_path}") + continue + + processed, elapsed = process_parquet_file(str(file_path), str(output_file_path), args, reward_pool) + total_processed += processed + total_elapsed += elapsed + successful_files += 1 + + # Cleanup + if reward_pool: + reward_pool.close() + reward_pool.join() + + # Summary + print(f"\nComplete: {successful_files}/{len(parquet_files)} files") + print(f"Total responses: {total_processed}") + print(f"Total time: {format_time(total_elapsed)}") + if total_processed > 0: + print(f"Time per response: {total_elapsed/total_processed:.3f}s") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/tools/serve_llm_as_verifier.sh b/scripts/tools/serve_llm_as_verifier.sh index 15f6689e1..041ea647a 100644 --- a/scripts/tools/serve_llm_as_verifier.sh +++ b/scripts/tools/serve_llm_as_verifier.sh @@ -1,6 +1,6 @@ #!/bin/bash #SBATCH --job-name=server_llm_as_verifier -#SBATCH --partition=main +#SBATCH --account=iq #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH --cpus-per-task=64 @@ -10,6 +10,8 @@ #SBATCH --error=slurm/serve_llm_as_verifier_%j.err +source activate /mnt/weka/home/haonan.li/miniconda3/envs/Reasoning360 + # (1) detect this node’s primary IP NODE_IP=$(hostname -I | awk '{print $1}') echo "Detected NODE_IP = $NODE_IP" diff --git a/scripts/tools/serve_llm_as_verifier_m2.sh b/scripts/tools/serve_llm_as_verifier_m2.sh new file mode 100644 index 000000000..ed4101b35 --- /dev/null +++ b/scripts/tools/serve_llm_as_verifier_m2.sh @@ -0,0 +1,23 @@ +#!/bin/bash +#SBATCH --job-name=llm_as_verifier +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:8 +#SBATCH --time=720:00:00 +#SBATCH --output=slurm/serve_llm_as_verifier_%j.log +#SBATCH --error=slurm/serve_llm_as_verifier_%j.err + + +source activate /lustrefs/users/haonan.li/miniconda3/envs/Reasoning360 + +# (1) detect this node’s primary IP +NODE_IP=$(hostname -I | awk '{print $1}') +echo "Detected NODE_IP = $NODE_IP" + +# (2) export judge URL for downstream clients +export STEM_LLM_JUDGE_URL="http://${NODE_IP}:8000" +echo "STEM_LLM_JUDGE_URL=$STEM_LLM_JUDGE_URL" + +# (3) launch the vLLM server bound to that IP +vllm serve TIGER-Lab/general-verifier --host "$NODE_IP" --data-parallel-size 8 diff --git a/scripts/tools/stats.py b/scripts/tools/stats.py new file mode 100644 index 000000000..34315764b --- /dev/null +++ b/scripts/tools/stats.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python3 +""" +Visualize pass rate and response length distributions from scored parquet files. +""" + +import argparse +import os +from pathlib import Path +from typing import List, Optional, Dict, Any + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from datasets import Dataset + + +def read_parquet_file(file_path: str) -> pd.DataFrame: + """Read a parquet file using HuggingFace datasets.""" + if not os.path.exists(file_path) or os.path.getsize(file_path) == 0: + return pd.DataFrame() + + dataset = Dataset.from_parquet(file_path) + return dataset.to_pandas() + + +def extract_pass_rates(df: pd.DataFrame) -> List[float]: + """Extract pass rates from DataFrame.""" + pass_rates = [] + for _, row in df.iterrows(): + if 'pass_rate' in row and row['pass_rate'] is not None: + pass_rates.append(row['pass_rate']) + return pass_rates + + +def extract_response_lengths(df: pd.DataFrame) -> List[int]: + """Extract all response lengths from DataFrame, filtering to 32k limit.""" + all_lengths = [] + filtered_count = 0 + total_count = 0 + + for _, row in df.iterrows(): + if 'response_lengths' in row and row['response_lengths'] is not None: + lengths = row['response_lengths'] + if isinstance(lengths, (list, np.ndarray)): + for length in lengths: + if length > 0: + total_count += 1 + if length > 32000: + filtered_count += 1 + else: + all_lengths.append(length) + + if filtered_count > 0: + print(f"Warning: Filtered out {filtered_count} responses longer than 32,000 tokens ({filtered_count/total_count*100:.1f}% of total)") + + return all_lengths + + +def extract_passed_response_lengths(df: pd.DataFrame, threshold: float = 1.0) -> List[int]: + """Extract response lengths for passed responses only, filtering to 32k limit.""" + passed_lengths = [] + filtered_count = 0 + total_passed_count = 0 + + for _, row in df.iterrows(): + if ('response_lengths' in row and 'scores' in row and + row['response_lengths'] is not None and row['scores'] is not None): + lengths = row['response_lengths'] + scores = row['scores'] + if isinstance(lengths, (list, np.ndarray)) and isinstance(scores, (list, np.ndarray)): + for length, score in zip(lengths, scores): + if score >= threshold and length > 0: + total_passed_count += 1 + if length > 32000: + filtered_count += 1 + else: + passed_lengths.append(length) + + if filtered_count > 0: + print(f"Warning: Filtered out {filtered_count} passed responses longer than 32,000 tokens ({filtered_count/total_passed_count*100:.1f}% of passed responses)") + + return passed_lengths + + +def plot_pass_rate_distribution(pass_rates: List[float], output_dir: Path, model_name: str): + """Plot pass rate distribution.""" + if not pass_rates: + print("No pass rate data to plot") + return + + plt.figure(figsize=(10, 6)) + + # Histogram only - cleaner and more readable + plt.hist(pass_rates, bins=30, alpha=0.7, edgecolor='black', color='skyblue') + plt.xlabel('Pass Rate', fontsize=12) + plt.ylabel('Frequency', fontsize=12) + plt.title(f'Pass Rate Distribution - {model_name}', fontsize=14, fontweight='bold') + plt.grid(True, alpha=0.3) + + # Add statistics text box + mean_pass_rate = np.mean(pass_rates) + median_pass_rate = np.median(pass_rates) + std_pass_rate = np.std(pass_rates) + + stats_text = f'Mean: {mean_pass_rate:.3f}\nMedian: {median_pass_rate:.3f}\nStd: {std_pass_rate:.3f}\nCount: {len(pass_rates)}' + plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8), + verticalalignment='top', fontsize=10) + + plt.tight_layout() + + # Save plot + output_file = output_dir / f'{model_name}_pass_rate_distribution.png' + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + print(f"Pass rate distribution saved to: {output_file}") + + # Print statistics + mean_pass_rate = np.mean(pass_rates) + median_pass_rate = np.median(pass_rates) + std_pass_rate = np.std(pass_rates) + + print(f"Pass Rate Statistics:") + print(f" Mean: {mean_pass_rate:.3f}") + print(f" Median: {median_pass_rate:.3f}") + print(f" Std: {std_pass_rate:.3f}") + print(f" Min: {np.min(pass_rates):.3f}") + print(f" Max: {np.max(pass_rates):.3f}") + + +def plot_length_distribution(all_lengths: List[int], passed_lengths: List[int], + output_dir: Path, model_name: str): + """Plot response length distributions with improved readability.""" + if not all_lengths: + print("No response length data to plot") + return + + plt.figure(figsize=(15, 8)) + + # Log scale histogram comparison + plt.subplot(2, 2, 1) + bins = np.logspace(np.log10(min(all_lengths)), np.log10(max(all_lengths)), 40) + plt.hist(all_lengths, bins=bins, alpha=0.7, label='All Responses', density=True, color='skyblue') + if passed_lengths: + plt.hist(passed_lengths, bins=bins, alpha=0.7, label='Passed Responses', density=True, color='lightcoral') + plt.xlabel('Response Length (tokens)', fontsize=12) + plt.ylabel('Density', fontsize=12) + plt.title(f'Response Length Distribution (Log Scale) - {model_name}', fontsize=13, fontweight='bold') + plt.xscale('log') + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + + # Linear scale histogram for detailed view + plt.subplot(2, 2, 2) + plt.hist(all_lengths, bins=50, alpha=0.7, label='All Responses', density=True, color='skyblue') + if passed_lengths: + plt.hist(passed_lengths, bins=50, alpha=0.7, label='Passed Responses', density=True, color='lightcoral') + plt.xlabel('Response Length (tokens)', fontsize=12) + plt.ylabel('Density', fontsize=12) + plt.title(f'Response Length Distribution (Linear Scale) - {model_name}', fontsize=13, fontweight='bold') + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + + # CDF comparison + plt.subplot(2, 2, 3) + sorted_all = np.sort(all_lengths) + cdf_all = np.arange(1, len(sorted_all) + 1) / len(sorted_all) + plt.plot(sorted_all, cdf_all, label='All Responses', linewidth=2, color='steelblue') + + if passed_lengths: + sorted_passed = np.sort(passed_lengths) + cdf_passed = np.arange(1, len(sorted_passed) + 1) / len(sorted_passed) + plt.plot(sorted_passed, cdf_passed, label='Passed Responses', linewidth=2, color='crimson') + + plt.xlabel('Response Length (tokens)', fontsize=12) + plt.ylabel('Cumulative Probability', fontsize=12) + plt.title(f'Response Length CDF - {model_name}', fontsize=13, fontweight='bold') + plt.xscale('log') + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + + # Statistics comparison + plt.subplot(2, 2, 4) + stats_text_all = (f'All Responses:\n' + f'Mean: {np.mean(all_lengths):.1f}\n' + f'Median: {np.median(all_lengths):.1f}\n' + f'Count: {len(all_lengths)}') + + if passed_lengths: + stats_text_passed = (f'Passed Responses:\n' + f'Mean: {np.mean(passed_lengths):.1f}\n' + f'Median: {np.median(passed_lengths):.1f}\n' + f'Count: {len(passed_lengths)}') + stats_text = f'{stats_text_all}\n\n{stats_text_passed}' + else: + stats_text = stats_text_all + + plt.text(0.1, 0.5, stats_text, transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8), + fontsize=11, verticalalignment='center') + plt.axis('off') + plt.title(f'Response Length Statistics - {model_name}', fontsize=13, fontweight='bold') + + plt.tight_layout() + + # Save plot + output_file = output_dir / f'{model_name}_response_length_distribution.png' + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + print(f"Response length distribution saved to: {output_file}") + + # Print statistics + print(f"Response Length Statistics:") + print(f" All Responses - Mean: {np.mean(all_lengths):.1f}, Median: {np.median(all_lengths):.1f}, Count: {len(all_lengths)}") + if passed_lengths: + print(f" Passed Responses - Mean: {np.mean(passed_lengths):.1f}, Median: {np.median(passed_lengths):.1f}, Count: {len(passed_lengths)}") + + +def plot_pass_rate_vs_length(df: pd.DataFrame, output_dir: Path, model_name: str): + """Plot pass rate vs average response length correlation.""" + pass_rates = [] + avg_lengths = [] + filtered_entries = 0 + total_entries = 0 + + for _, row in df.iterrows(): + if ('pass_rate' in row and 'response_lengths' in row and + row['pass_rate'] is not None and row['response_lengths'] is not None): + lengths = row['response_lengths'] + if isinstance(lengths, (list, np.ndarray)) and len(lengths) > 0: + total_entries += 1 + valid_lengths = [l for l in lengths if 0 < l <= 32000] + filtered_lengths = [l for l in lengths if l > 32000] + + if filtered_lengths: + filtered_entries += 1 + + if valid_lengths: + pass_rates.append(row['pass_rate']) + avg_lengths.append(np.mean(valid_lengths)) + + if filtered_entries > 0: + print(f"Warning: {filtered_entries} entries had responses longer than 32,000 tokens ({filtered_entries/total_entries*100:.1f}% of entries)") + + if not pass_rates: + print("No data for pass rate vs length correlation") + return + + plt.figure(figsize=(12, 6)) + + # Scatter plot + plt.subplot(1, 2, 1) + plt.scatter(avg_lengths, pass_rates, alpha=0.6, color='steelblue', s=30) + plt.xlabel('Average Response Length (tokens)', fontsize=12) + plt.ylabel('Pass Rate', fontsize=12) + plt.title(f'Pass Rate vs Average Response Length - {model_name}', fontsize=14, fontweight='bold') + plt.grid(True, alpha=0.3) + + # Correlation coefficient + correlation = np.corrcoef(avg_lengths, pass_rates)[0, 1] + plt.text(0.05, 0.95, f'Correlation: {correlation:.3f}', + transform=plt.gca().transAxes, bbox=dict(boxstyle="round", facecolor='wheat', alpha=0.8), + fontsize=11) + + # Hexbin plot for density + plt.subplot(1, 2, 2) + plt.hexbin(avg_lengths, pass_rates, gridsize=25, cmap='Blues', mincnt=1) + plt.xlabel('Average Response Length (tokens)', fontsize=12) + plt.ylabel('Pass Rate', fontsize=12) + plt.title(f'Pass Rate vs Length Density - {model_name}', fontsize=14, fontweight='bold') + cbar = plt.colorbar(label='Count') + cbar.ax.tick_params(labelsize=10) + + # Add statistics to the density plot + stats_text = (f'Length Mean: {np.mean(avg_lengths):.1f}\n' + f'Pass Rate Mean: {np.mean(pass_rates):.3f}\n' + f'Sample Count: {len(pass_rates)}') + plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8), + verticalalignment='top', fontsize=10) + + plt.tight_layout() + + # Save plot + output_file = output_dir / f'{model_name}_pass_rate_vs_length.png' + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + print(f"Pass rate vs length correlation saved to: {output_file}") + print(f"Correlation between pass rate and average response length: {correlation:.3f}") + + +def collect_file_statistics(df: pd.DataFrame, file_name: str) -> Dict[str, Any]: + """Collect statistics for a single file.""" + stats = {"file_name": file_name} + + # Extract data + pass_rates = extract_pass_rates(df) + response_lengths = extract_response_lengths(df) + + # Pass rate statistics + if pass_rates: + stats.update({ + "pass_rate_mean": np.mean(pass_rates), + "pass_rate_median": np.median(pass_rates), + "pass_rate_std": np.std(pass_rates), + "pass_rate_min": np.min(pass_rates), + "pass_rate_max": np.max(pass_rates), + "pass_rate_count": len(pass_rates) + }) + else: + stats.update({ + "pass_rate_mean": 0, "pass_rate_median": 0, "pass_rate_std": 0, + "pass_rate_min": 0, "pass_rate_max": 0, "pass_rate_count": 0 + }) + + # Response length statistics + if response_lengths: + stats.update({ + "length_min": np.min(response_lengths), + "length_max": np.max(response_lengths), + "length_25th": np.percentile(response_lengths, 25), + "length_75th": np.percentile(response_lengths, 75), + "length_mean": np.mean(response_lengths), + "length_median": np.median(response_lengths), + "length_count": len(response_lengths) + }) + else: + stats.update({ + "length_min": 0, "length_max": 0, "length_25th": 0, + "length_75th": 0, "length_mean": 0, "length_median": 0, "length_count": 0 + }) + + return stats + + +def print_summary_table(all_file_stats: List[Dict[str, Any]]): + """Print a comprehensive summary table of all processed files.""" + if not all_file_stats: + print("No files processed for summary.") + return + + print("\n" + "="*120) + print("SUMMARY STATISTICS FOR ALL FILES") + print("="*120) + + # Pass Rate Summary Table + print("\nPASS RATE STATISTICS:") + print("-" * 100) + header = f"{'File Name':<35} {'Mean':<8} {'Median':<8} {'Std':<8} {'Min':<8} {'Max':<8} {'Count':<8}" + print(header) + print("-" * 100) + + for stats in all_file_stats: + row = (f"{stats['file_name'][:33]:<35} " + f"{stats['pass_rate_mean']:<8.3f} " + f"{stats['pass_rate_median']:<8.3f} " + f"{stats['pass_rate_std']:<8.3f} " + f"{stats['pass_rate_min']:<8.3f} " + f"{stats['pass_rate_max']:<8.3f} " + f"{stats['pass_rate_count']:<8}") + print(row) + + # Response Length Summary Table + print("\nRESPONSE LENGTH STATISTICS:") + print("-" * 110) + header = f"{'File Name':<35} {'Min':<8} {'25%':<8} {'Mean':<8} {'Median':<8} {'75%':<8} {'Max':<8} {'Count':<8}" + print(header) + print("-" * 110) + + for stats in all_file_stats: + row = (f"{stats['file_name'][:33]:<35} " + f"{stats['length_min']:<8.0f} " + f"{stats['length_25th']:<8.0f} " + f"{stats['length_mean']:<8.0f} " + f"{stats['length_median']:<8.0f} " + f"{stats['length_75th']:<8.0f} " + f"{stats['length_max']:<8.0f} " + f"{stats['length_count']:<8}") + print(row) + + # Overall Summary + print("\nOVERALL SUMMARY:") + print("-" * 60) + + # Aggregate pass rate statistics + all_pass_rates = [] + all_lengths = [] + total_files = len(all_file_stats) + + for stats in all_file_stats: + if stats['pass_rate_count'] > 0: + # Weight by count for proper aggregation + all_pass_rates.extend([stats['pass_rate_mean']] * stats['pass_rate_count']) + if stats['length_count'] > 0: + all_lengths.extend([stats['length_mean']] * stats['length_count']) + + if all_pass_rates: + overall_pass_rate_mean = np.mean(all_pass_rates) + overall_pass_rate_std = np.std(all_pass_rates) + print(f"Overall Pass Rate - Mean: {overall_pass_rate_mean:.3f}, Std: {overall_pass_rate_std:.3f}") + + if all_lengths: + overall_length_mean = np.mean(all_lengths) + overall_length_std = np.std(all_lengths) + print(f"Overall Length - Mean: {overall_length_mean:.0f}, Std: {overall_length_std:.0f} tokens") + + print(f"Total Files Processed: {total_files}") + print("="*120) + + +def main(): + parser = argparse.ArgumentParser(description="Visualize pass rate and response length distributions") + + parser.add_argument("input_path", help="Directory or parquet file path with scored data") + parser.add_argument("--output_dir", default="./figures", help="Output directory for figures") + parser.add_argument("--model_name", default="model", help="Model name for plot titles") + parser.add_argument("--pattern", default="*.parquet", help="File pattern for directories") + parser.add_argument("--recursive", "-r", action="store_true", help="Recursive search") + parser.add_argument("--correct_reward_threshold", type=float, default=1.0, help="Pass threshold") + + args = parser.parse_args() + + # Setup paths + input_path = Path(args.input_path) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + if not input_path.exists(): + print(f"Input path does not exist: {args.input_path}") + return + + # Find files + if input_path.is_file(): + parquet_files = [input_path] + else: + if args.recursive: + parquet_files = list(input_path.rglob(args.pattern)) + else: + parquet_files = list(input_path.glob(args.pattern)) + + if not parquet_files: + print(f"No files matching '{args.pattern}' found") + return + + parquet_files.sort() + + print(f"Found {len(parquet_files)} parquet files") + + # Set style + plt.style.use('default') + sns.set_palette("husl") + + # Process each file individually + processed_files = 0 + all_file_stats = [] + + for file_path in parquet_files: + print(f"\nProcessing: {file_path.name}") + df = read_parquet_file(str(file_path)) + + if df.empty: + print(f" Skipping empty file: {file_path.name}") + continue + + # Extract data for this file + pass_rates = extract_pass_rates(df) + response_lengths = extract_response_lengths(df) + passed_lengths = extract_passed_response_lengths(df, args.correct_reward_threshold) + + print(f" Found {len(pass_rates)} pass rates, {len(response_lengths)} response lengths") + + if not pass_rates and not response_lengths: + print(f" No data found in {file_path.name}") + continue + + # Collect statistics for this file + file_stats = collect_file_statistics(df, file_path.name) + all_file_stats.append(file_stats) + + # Create file-specific model name + file_stem = file_path.stem + file_model_name = f"{args.model_name}_{file_stem}" if args.model_name != "model" else file_stem + + # Create file-specific output directory + file_output_dir = output_dir / file_stem + file_output_dir.mkdir(parents=True, exist_ok=True) + + print(f" Generating figures for {file_path.name}...") + + # Generate visualizations for this file + if pass_rates: + plot_pass_rate_distribution(pass_rates, file_output_dir, file_model_name) + + if response_lengths: + plot_length_distribution(response_lengths, passed_lengths, file_output_dir, file_model_name) + + if pass_rates and response_lengths: + plot_pass_rate_vs_length(df, file_output_dir, file_model_name) + + processed_files += 1 + print(f" Figures for {file_path.name} saved to: {file_output_dir}") + + print(f"\nProcessed {processed_files}/{len(parquet_files)} files successfully") + print(f"All figures saved under: {output_dir}") + + # Print summary table for all files + if all_file_stats: + print_summary_table(all_file_stats) + + +if __name__ == "__main__": + main() diff --git a/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh b/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh new file mode 100644 index 000000000..2c16eddf9 --- /dev/null +++ b/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh @@ -0,0 +1,269 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-llama3.1-70b-distill-fsdp +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=./data +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ +TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval/ + +# Math (train) +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# Stem (test) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Llama-70B + +# =================== Logging =================== +WANDB_PROJECT=Reasoning360 +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=4 +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh b/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh new file mode 100644 index 000000000..584b58197 --- /dev/null +++ b/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh @@ -0,0 +1,291 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-llama3.1-70b-distill-megatron +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" +export STEM_LLM_JUDGE_URL="" + +# =================== Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids + +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 +export RAY_record_ref_creation_sites=1 # NOTE(yonghao): DEBUG code +# export GLOO_SOCKET_IFNAME=ens10f0np0 + + +# =================== Data Mixture =================== +SHARED_DATA_PATH=./data +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ +TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval/ + +# Math (train) +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# Stem (test) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL="deepseek-ai/DeepSeek-R1-Distill-Llama-70B" + +# =================== Logging =================== +WANDB_PROJECT=Reasoning360 +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_MEGATRON_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_MEGATRON_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_MEGATRON_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=16 +train_prompt_bsz=256 # grad accum bsz; real grad accum bsz: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) # rollout bsz, i.e., the x-axis in RL plot +n_resp_per_prompt=16 +train_prompt_mini_bsz=8 + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Generation config +gen_tp=4 +gen_max_num_seqs=1024 + +# Megatron trainer config +train_tp=8 +train_pp=2 +sp_size=8 +offload=True + +# Batch size +use_dynamic_bsz=True +train_micro_batch_size=null +train_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in ray_trainer.py ```minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``` +infer_micro_batch_size_per_gpu_placeholder=8 # can't be null, as in megatron_worker.py ```assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time."``` +# NOTE: this one is for per gpu, so it times sp_size (defined later) +# actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) +actor_ppo_max_token_len=8192 +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) + + +# NOTE(yonghao): all other parts (weights, optimizer states) exists across stages (training, generation) +# while this one only lives during a training iteration. +grad_offload=True +#### + +# =================== Start RL training =================== +"${CONDA_BIN_MEGATRON_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_megatron_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="megatron" \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_warmup_init=0.0 \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_decay_style=constant \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.min_lr=0. \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${grad_offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${sp_size} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${sp_size} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.65 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + +actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh b/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh new file mode 100644 index 000000000..043ed5179 --- /dev/null +++ b/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh @@ -0,0 +1,270 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-qwen2.5-32b-base-fsdp +#SBATCH --nodes=8 +#SBATCH --ntasks=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.81:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +# SHARED_DATA_PATH=./data +SHARED_DATA_PATH=/mnt/sharefs/users/chengqian.gao/guru +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train +TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval + +# Math (train) +math_train_path=${TRAIN_DATA_DIR}/math__combined_5k.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# Stem (test) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-32B + +# =================== Logging =================== +WANDB_PROJECT=Difficulty-Aware-RL +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh b/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh new file mode 100644 index 000000000..e00c1bb12 --- /dev/null +++ b/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh @@ -0,0 +1,288 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-qwen2.5-32b-base-megatron +#SBATCH --nodes=8 +#SBATCH --ntasks=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_LAUNCH_BLOCKING=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=./data +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ +TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval/ + +# Math (train) +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# Stem (test) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-32B + +# =================== Logging =================== +WANDB_PROJECT=Reasoning360 +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_MEGATRON_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + ${CONDA_BIN_MEGATRON_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + ${CONDA_BIN_MEGATRON_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # grad accum bsz; real grad accum bsz: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) # rollout bsz, i.e., the x-axis in RL plot +n_resp_per_prompt=16 +train_prompt_mini_bsz=64 + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Generation config +gen_tp=2 +gen_max_num_seqs=1024 + +# Megatron trainer config +train_tp=8 +train_pp=1 +sp_size=2 +offload=True + +# Batch size +use_dynamic_bsz=True +train_micro_batch_size=null +train_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in ray_trainer.py ```minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``` +infer_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in megatron_worker.py ```assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time."``` +# NOTE: this one is for per gpu, so it times sp_size (defined later) +# actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2 )) +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2 )) + + +# NOTE(yonghao): all other parts (weights, optimizer states) exists across stages (training, generation) +# while this one only lives during a training iteration. +grad_offload=True +#### + +# =================== Start RL training =================== +"${CONDA_BIN_MEGATRON_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_megatron_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="megatron" \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_warmup_init=0.0 \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_decay_style=constant \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.min_lr=0. \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${grad_offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${sp_size} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${sp_size} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + +actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.log_val_generations=50 \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh b/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh new file mode 100644 index 000000000..f8f2d95b1 --- /dev/null +++ b/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh @@ -0,0 +1,240 @@ +#!/bin/bash + +# =================== User-Configurable Settings =================== +# --- Execution Environment --- +NUM_GPUS=8 # Set the number of GPUs to use on this node + +# --- Resuming & Logging --- +RESUME_CKPT_DIR_NAME="" # Fill in the W&B experiment name to resume from, otherwise leave empty to start from scratch +WANDB_PROJECT="Reasoning360" # Your wandb project name + +# --- External Services --- +export STEM_LLM_JUDGE_URL="" # Optional: Fill in the llm-as-judge hosted URL for 'STEM' domain evaluation + +# =================== Environment Setup =================== +export NCCL_DEBUG=info +export CUDA_DEVICE_MAX_CONNECTIONS=1 +# export CUDA_LAUNCH_BLOCKING=1 # Uncomment for easier debugging of CUDA errors + +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +# Assumes data is in a directory named 'data' in the same directory as the script +SHARED_DATA_PATH=./data +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ +TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval/ + +# Math (train) +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet +# Simulation (test) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# Stem (test) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +# --- Select your training/testing data --- +# Here we use math as an example. You can combine multiple datasets. +# Example for multiple files: "['path/to/file1.parquet', 'path/to/file2.parquet']" +train_files="['${math_train_path}']" +test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-7B + +# =================== Logging =================== +# Generate a unique experiment name if not resuming +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +else + TIMESTAMP=$(date +%Y%m%d-%H%M%S) + WANDB_EXPERIMENT_NAME="single-node-${TIMESTAMP}-${BASE_MODEL##*/}" +fi + +# =================== Ray Start (Single Node) =================== +# Stop any previous Ray instances +${CONDA_BIN_PATH}ray stop -f + +# Start a new Ray cluster on the local machine +# The number of CPUs is often best left for Ray to determine automatically. +echo "Starting Ray on the local node with ${NUM_GPUS} GPUs..." +${CONDA_BIN_PATH}ray start --head --num-gpus ${NUM_GPUS} --include-dashboard=True --dashboard-port 8265 +sleep 5 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=64 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +# NOTE: sp_size and gen_tp are parallelism settings. +# sp_size: Sequence Parallelism size. +# gen_tp: Tensor Parallelism size for vLLM generation. +# For a 32B model on 8 GPUs, TP=2 is a reasonable starting point. Adjust if you have memory issues. +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward, but note memory overflow +offload=True + +# =================== Start RL training =================== +# Ensure your python environment (e.g., conda) is activated before running this script. +echo "Starting training..." +python -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=${NUM_GPUS} \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_qwen32b_base.sh b/scripts/train/example_singlenode_rl_qwen7b_synlogic.sh similarity index 91% rename from scripts/train/example_multinode_rl_qwen32b_base.sh rename to scripts/train/example_singlenode_rl_qwen7b_synlogic.sh index dfa538fe5..fc4516a00 100644 --- a/scripts/train/example_multinode_rl_qwen32b_base.sh +++ b/scripts/train/example_singlenode_rl_qwen7b_synlogic.sh @@ -1,16 +1,14 @@ #!/bin/bash -#SBATCH --job-name=example-multinode-rl-qwen32b-base -#SBATCH --partition=main -#SBATCH --nodes=8 -#SBATCH --ntasks=8 +#SBATCH --job-name=Qwen2-7B-OC +#SBATCH --nodes=2 +#SBATCH --ntasks=2 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=96 -#SBATCH --mem=512G +#SBATCH --cpus-per-task=10 #SBATCH --output=slurm/%x-%j.out #SBATCH --error=slurm/%x-%j.err -#SBATCH --exclusive -#SBATCH --time=720:00:00 +#SBATCH --account=iq +#SBATCH --mem=512G # =================== Frequently Used Variables =================== @@ -42,8 +40,12 @@ export VLLM_USE_V1=0 # =================== Data Mixture =================== SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/offline_eval/ +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train +TEST_DATA_DIR=${SHARED_DATA_PATH}/test + +# synlogic +synlogic_train_path=${TRAIN_DATA_DIR}/synlogic_object_counting_train.parquet +synlogic_test_path=${TEST_DATA_DIR}/synlogic_object_counting_test.parquet # Math (train) math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet @@ -92,12 +94,14 @@ webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}']" # Use math as example, add to more tasks as needed +train_files="['${synlogic_train_path}']" # Use math as example, add to more tasks as needed +test_files="['${synlogic_test_path}']" # Use math as example, add to more tasks as needed # =================== Model =================== -BASE_MODEL=Qwen/Qwen2.5-32B # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). - +# BASE_MODEL=Qwen/Qwen2.5-32B # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). +# BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B +BASE_MODEL=Qwen/Qwen2.5-7B-instruct +# BASE_MODEL=Qwen/Qwen3-4B-Thinking-2507 # =================== Logging =================== WANDB_PROJECT=Reasoning360 WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} @@ -158,10 +162,10 @@ loss_agg_mode="token-mean" enable_filter_groups=False filter_groups_metric=acc max_num_gen_batches=10 -train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n +train_prompt_bsz=128 # on-policy model update batchsize: train_prompt_bsz * rollout.n gen_prompt_bsz=$((train_prompt_bsz * 1)) n_resp_per_prompt=16 -train_prompt_mini_bsz=64 # model grad update batchsize +train_prompt_mini_bsz=16 # model grad update batchsize # Algorithm temperature=1.0 @@ -169,7 +173,7 @@ top_p=1.0 top_k=-1 # 0 for HF rollout, -1 for vLLM rollout # Mathematically equivalent -sp_size=8 +sp_size=4 gen_tp=4 infer_micro_batch_size=null train_micro_batch_size=null diff --git a/scripts/train/ifbench_test.sh b/scripts/train/ifbench_test.sh new file mode 100644 index 000000000..2c559f411 --- /dev/null +++ b/scripts/train/ifbench_test.sh @@ -0,0 +1,239 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-qwen7b-instruct-IFbench-test +#SBATCH --partition=main +#SBATCH --nodes=2 +#SBATCH --ntasks=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=512G +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.81:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export TRITON_HOME="/tmp/triton_cache" +# export SANDBOX_FUSION_SERVERS="fs-mbz-cpu-002" + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +echo "Head node: $head_node" + +# Get head node IP address with better error handling +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" --cpu-bind=none hostname --ip-address) +if [ -z "$head_node_ip" ]; then + echo "Error: Could not get IP address for head node $head_node" + exit 1 +fi +echo "Head node IP: $head_node_ip" + +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +# SHARED_DATA_PATH=./data +SHARED_DATA_PATH=/mnt/sharefs/users/jianshu.she/ + +# 使用ifbench数据 +IFBENCH_TRAIN_PATH=/mnt/sharefs/users/jianshu.she/ifbench_split/ifbench_train_fixed.parquet +IFBENCH_TEST_PATH=/mnt/sharefs/users/jianshu.she/ifbench_split/ifbench_test_fixed.parquet + +# All training files across all domains +train_files="['${IFBENCH_TRAIN_PATH}']" + +# All test files across all domains +test_files="['${IFBENCH_TEST_PATH}']" + +# =================== Model =================== +# BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). +BASE_MODEL=Qwen/Qwen2.5-7B-Instruct # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). +CONDA_BIN_PATH=/mnt/weka/home/jianshu.she/miniconda3/envs/Reasoning360/bin/ +# =================== Logging =================== +WANDB_PROJECT=IFbench-test +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +echo "Stopping Ray on all nodes..." +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 --cpu-bind=none ray stop || true + +sleep 10 +# Remove existing Ray cluster +echo "Removing existing Ray cluster..." +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 --cpu-bind=none rm -rf /tmp/ray/ray_current_cluster || true + +# Start Ray head node +echo "Starting Ray head node on $head_node with IP $head_node_ip..." +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL --cpu-bind=none \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL --cpu-bind=none \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Mathematically equivalent +sp_size=1 +gen_tp=1 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +cd /mnt/weka/home/jianshu.she/IFM/Reasoning360 +"${CONDA_BIN_PATH}python" -m verl.recipe.dapo.src.main_dapo \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.rollout.multi_turn.enable=False \ + +actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.nnodes=$worker_num \ + trainer.save_freq=50 \ + trainer.test_freq=50 \ + trainer.total_epochs=5 \ + +trainer.vary_length=False \ + +trainer.val_generations_to_log_to_wandb=30 \ + trainer.resume_mode=auto \ No newline at end of file diff --git a/scripts/train/reasoning_gym_test.sh b/scripts/train/reasoning_gym_test.sh new file mode 100644 index 000000000..08294384e --- /dev/null +++ b/scripts/train/reasoning_gym_test.sh @@ -0,0 +1,266 @@ +#!/bin/bash +#SBATCH --job-name=example-multinode-rl-qwen7b-instruct-reasoning-gym-test +#SBATCH --partition=main +#SBATCH --nodes=2 +#SBATCH --ntasks=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=512G +#SBATCH --output=slurm/%x-%j.out +#SBATCH --error=slurm/%x-%j.err +#SBATCH --exclusive +#SBATCH --time=720:00:00 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.1.81:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export PATH=/mnt/weka/home/jianshu.she/miniconda3/envs/Reasoning360/bin:$PATH +# =================== Cluster Environment =================== +export NCCL_DEBUG=info +export NCCL_ALGO=NVLSTree +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 +export TRITON_HOME="/tmp/triton_cache" +# export SANDBOX_FUSION_SERVERS="fs-mbz-cpu-002" + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +echo "Head node: $head_node" + +# Get head node IP address with better error handling +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" --cpu-bind=none hostname --ip-address) +if [ -z "$head_node_ip" ]; then + echo "Error: Could not get IP address for head node $head_node" + exit 1 +fi +echo "Head node IP: $head_node_ip" + +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=0 + +# =================== Data Mixture =================== +# SHARED_DATA_PATH=./data +SHARED_DATA_PATH=/mnt/sharefs/users/jianshu.she/ + +# 使用ifbench数据 +REASONING_GYM_TRAIN_PATH=/mnt/sharefs/users/jianshu.she/train_reasoning_gym.parquet +REASONING_GYM_TEST_PATH=/mnt/sharefs/users/jianshu.she/test_reasoning_gym.parquet + +# All training files across all domains +train_files="['${REASONING_GYM_TRAIN_PATH}']" + +# All test files across all domains +test_files="['${REASONING_GYM_TEST_PATH}']" + +# =================== Model =================== +# BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). +# BASE_MODEL=Qwen/Qwen2.5-7B-Instruct # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README). +# BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +BASE_MODEL=/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-Instruct + +CONDA_BIN_PATH=/mnt/weka/home/jianshu.she/miniconda3/envs/Reasoning360/bin/ +source /mnt/weka/home/jianshu.she/miniconda3/bin/activate Reasoning360 + +# =================== Logging =================== +WANDB_PROJECT=reasoning-gym-test +WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +echo "Stopping Ray on all nodes..." +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 --cpu-bind=none ray stop || true + +sleep 10 +# Remove existing Ray cluster +echo "Removing existing Ray cluster..." +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 --cpu-bind=none rm -rf /tmp/ray/ray_current_cluster || true + +# Start Ray head node +echo "Starting Ray head node on $head_node with IP $head_node_ip..." +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL --cpu-bind=none \ + bash -c " + source /mnt/weka/home/jianshu.she/miniconda3/bin/activate Reasoning360 + if ! pip show reasoning_gym > /dev/null 2>&1; then + echo 'reasoning_gym not found, installing...' + pip install reasoning_gym + fi + echo '==== Node: $(hostname) ====' + which python + python --version + pip show reasoning_gym || echo 'reasoning_gym NOT FOUND' + python -c 'import reasoning_gym; print(\"reasoning_gym imported OK\")' || echo 'reasoning_gym import FAILED' + ${CONDA_BIN_PATH}ray start --head --node-ip-address=\"$head_node_ip\" --port=$port --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus 8 --include-dashboard=True --block + " & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL --cpu-bind=none \ + bash -c " + source /mnt/weka/home/jianshu.she/miniconda3/bin/activate Reasoning360 + if ! pip show reasoning_gym > /dev/null 2>&1; then + echo 'reasoning_gym not found, installing...' + pip install reasoning_gym + fi + echo '==== Node: $(hostname) ====' + which python + python --version + pip show reasoning_gym || echo 'reasoning_gym NOT FOUND' + python -c 'import reasoning_gym; print(\"reasoning_gym imported OK\")' || echo 'reasoning_gym import FAILED' + ${CONDA_BIN_PATH}ray start --address \"$address_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus 8 --block + " & +done +sleep 10 + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Mathematically equivalent +sp_size=1 +gen_tp=1 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +cd /mnt/weka/home/jianshu.she/IFM/Reasoning360 +"${CONDA_BIN_PATH}python" -m verl.recipe.dapo.src.main_dapo \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.rollout.multi_turn.enable=False \ + +actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.reward_manager=async_dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.nnodes=$worker_num \ + trainer.save_freq=50 \ + trainer.test_freq=50 \ + trainer.total_epochs=5 \ + +trainer.vary_length=False \ + +trainer.val_generations_to_log_to_wandb=30 \ + trainer.resume_mode=auto \ No newline at end of file diff --git a/setup.py b/setup.py index 1cd5d7a4a..7f56b8667 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "datasets", "dill", "hydra-core", - "numpy", + "numpy<2.0.0", "pandas", "peft", "pyarrow>=19.0.0", @@ -37,8 +37,8 @@ "pylatexenc", "ray[default]>=2.41.0", "torchdata", - "tensordict<=0.6.2", # NOTE: Reasoning360 used a fixed version ==0.72 - "transformers", # NOTE Reasoning360 used a fixed version ==4.51.0 + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", + "transformers", "wandb", "packaging>=20.0", # NOTE: added by Reasoning360 @@ -48,19 +48,20 @@ "polars" ] -TEST_REQUIRES = ["pytest", "pre-commit", "py-spy"] +TEST_REQUIRES = ["pytest", "pre-commit", "py-spy", "pytest-asyncio"] PRIME_REQUIRES = ["pyext"] -GEO_REQUIRES = ["mathruler"] +GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] GPU_REQUIRES = ["liger-kernel", "flash-attn", "nvitop",] # NOTE: nvitop is added by Reasoning360 -# NOTE: Reasoning360 used "math-verify[antlr4_9_3]==0.6.0" MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency -VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"] +VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"] SGLANG_REQUIRES = [ - "tensordict<=0.6.2", + "tensordict>=0.8.0,<=0.9.1,!=0.9.0", "sglang[srt,openai]==0.4.6.post5", "torch-memory-saver>=0.0.5", "torch==2.6.0", ] +TRL_REQUIRES = ["trl<=0.9.6"] +MCORE_REQUIRES = ["mbridge"] extras_require = { "test": TEST_REQUIRES, @@ -70,6 +71,8 @@ "math": MATH_REQUIRES, "vllm": VLLM_REQUIRES, "sglang": SGLANG_REQUIRES, + "trl": TRL_REQUIRES, + "mcore": MCORE_REQUIRES, } diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..479f06933 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,30 @@ +# Tests layout + +Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +- `tests/trainer` for testing functionality related to `verl/trainer` +- `tests/models` for testing functionality related to `verl/models` +- ... + +There are a few folders with `special_` prefix, created for special purposes: +- `special_distributed`: unit tests that must run with multiple GPUs +- `special_e2e`: end-to-end tests with training/generation scripts +- `special_npu`: tests for NPUs +- `special_sanity`: a suite of quick sanity tests +- `special_standalone`: a set of test that are designed to run in dedicated environments + +Accelerators for tests +- By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +- For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# Workflow layout + +All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +3. End-to-end tests: `e2e_*.yml` +4. Unit tests + - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` + - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. + - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when + - new workflow yaml is added to `.github/workflows` + - new tests are added to workflow mentioned in 2. \ No newline at end of file diff --git a/tests/data_process/test_data_preprocess.py b/tests/data_process/test_data_preprocess.py deleted file mode 100644 index 50596013e..000000000 --- a/tests/data_process/test_data_preprocess.py +++ /dev/null @@ -1,179 +0,0 @@ -import random -import transformers -import pytest -import logging -import datasets - -from verl.utils.data_process.filter import LengthFilter -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -random.seed(42) - -MODEL_NAME_OR_PATH = "Qwen/Qwen2.5-7B" -tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH) - -cache_dir = datasets.config.HF_DATASETS_CACHE - -EXTRA_INST_MAP = { - "math": "Please output the final answer within \\boxed{}.", - "tablereason": "Please output the final answer within \\boxed{}.", - "simulation": "Please output the final answer within \\boxed{}.", -} - -DATASETS_CONFIG = [ - # math - {"domain": "math", "name": "bigmath_preview_filtered_mar21", "check_train": True, "check_test": False, "max_length": "unlimited"}, # -1 means no length filter applied - {"domain": "math", "name": "deepscaler_preview", "check_train": True, "check_test": False, "max_length": "unlimited"}, - # # code - {"domain": "codegen", "name": "leetcode2k", "check_train": False, "check_test": True, "max_length": 4096}, - {"domain": "codegen", "name": "primeintellect", "check_train": True, "check_test": False, "max_length": 4096}, - {"domain": "codegen", "name": "taco", "check_train": True, "check_test": False, "max_length": 4096}, - {"domain": "codegen", "name": "livecodebench", "check_train": True, "check_test": True, "max_length": 4096}, - {"domain": "codegen", "name": "humaneval", "check_train": False, "check_test": True, "max_length": 4096}, - {"domain": "codegen", "name": "mbpp", "check_train": True, "check_test": True, "max_length": 4096}, - # simulation - {"domain": "simulation", "name": "codeio", "check_train": True, "check_test": True, "max_length": 4096}, - # table - {"domain": "table", "name": "multihier", "check_train": True, "check_test": True, "max_length": 4096}, - {"domain": "table", "name": "hitab", "check_train": True, "check_test": True, "max_length": 4096}, - - # Add more datasets here as needed -] - -@pytest.fixture(scope="module") -def tokenizer_fixture(): - return transformers.AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH) - -def load_samples_to_check(data_domain: str, dataname: str, check_train: bool, check_test: bool, samples_to_check: int = 50): - """ - Load samples to check - """ - data_source = f"{data_domain}__{dataname}" - logger.info(f"Loading {samples_to_check} samples from {data_source}") - import importlib - - # if dataname == "bigmath_preview_filtered_mar21": - module_path = f"data_preprocess.{data_domain}.{dataname}" - module = importlib.import_module(module_path) - train_dataset, test_dataset = module.get_datasets(cache_dir) - - result_datasets = {} - if check_train: - train_dataset = train_dataset.select(random.sample(range(len(train_dataset)), samples_to_check)) - result_datasets["train"] = train_dataset.map(module.make_map_fn("train", data_source), with_indices=True) - if check_test: - test_dataset = test_dataset.select(random.sample(range(len(test_dataset)), samples_to_check)) - result_datasets["test"] = test_dataset.map(module.make_map_fn("test", data_source), with_indices=True) - return result_datasets - - -def check_length(data_entry, tokenizer, min_length, max_length, length_tolerance=100): - """ - Check the length of the prompt - """ - if max_length == "unlimited": - max_length = 128000 - - if "prompt" in data_entry and data_entry["prompt"]: - prompt_tokens = tokenizer.tokenize(tokenizer.apply_chat_template(data_entry["prompt"], tokenize=False)) - elif "raw_prompt" in data_entry and data_entry["raw_prompt"]: - prompt_tokens = tokenizer.tokenize(data_entry["raw_prompt"]) - else: - raise ValueError("No prompt found in data") - - token_length = len(prompt_tokens) - assert min_length <= token_length <= max_length - length_tolerance, \ - f"Token length {token_length} outside acceptable range [{min_length}, {max_length - length_tolerance}]" - -def check_data_source_format(data_entry, data_domain, data_name): - """ - Check the format of the data source - """ - assert "data_source" in data_entry, "Missing data_source in extra_info" - - data_source = data_entry["data_source"] - assert data_source is not None, "data_source is None" - assert "__" in data_source, f"Invalid data_source format: {data_source}" - - domain_in_data_source, name_in_data_source = data_source.split("__") - assert domain_in_data_source == data_domain, \ - f"Domain mismatch: {domain_in_data_source} != {data_domain}" - assert name_in_data_source == data_name, \ - f"Name mismatch: {name_in_data_source} != {data_name}" - - -def check_prompt_format(data_entry, extra_instruction): - """ - Check the format of the prompt - """ - if extra_instruction is not None: - raw_prompt = data_entry.get("raw_prompt") - assert raw_prompt is not None, "Missing raw_prompt in data entry" - assert extra_instruction in raw_prompt, \ - f"Extra instruction '{extra_instruction}' not found in prompt" - - -def check_special_tokens(data_entry, model_name_or_path): - """ - Check the special tokens in the prompt - """ - prompt = data_entry.get("raw_prompt") - assert prompt is not None, "Missing raw_prompt in data entry" - - logger.info(prompt) - if "Qwen" in model_name_or_path: - assert "<|im_start|>" in prompt, \ - "Missing <|im_start|> token in prompt for Qwen model" - assert "<|im_end|>" in prompt, \ - "Missing <|im_end|> token in prompt for Qwen model" - - assert "" in prompt, "Missing token in prompt" - -@pytest.mark.parametrize("dataset_config", DATASETS_CONFIG) -def test_dataset_format(dataset_config, tokenizer_fixture): - """ - Parameterized test for checking dataset format - """ - dataname = dataset_config["name"] - data_domain = dataset_config["domain"] - check_train = dataset_config["check_train"] - check_test = dataset_config["check_test"] - min_length = dataset_config.get("min_length", 20) - max_length = dataset_config.get("max_length", 2048) - - logger.info(f"Testing dataset: {dataname}") - logger.info(f"Max length: {max_length}") - datasets = load_samples_to_check( - data_domain=data_domain, - dataname=dataname, - check_train=check_train, - check_test=check_test, - samples_to_check=50 - ) - - for split, dataset in datasets.items(): - if dataset is None: - continue - # Skip samples that failed to pass the coding unittests, which return empty entries - dataset = dataset.filter(lambda x: x["raw_prompt"] is not None and x["prompt"] is not None) - - # Filter by length, which is done in individual script - if max_length != "unlimited": - length_filter = LengthFilter(tokenizer=tokenizer_fixture, max_length=max_length) - dataset = dataset.filter(lambda x: length_filter.check(x)) - - logger.info(f"Testing {split} dataset: {data_domain}__{dataname}") - for i, data_entry in enumerate(dataset): - logger.debug(f"Checking sample {i+1}/{len(dataset)}") - - check_length(data_entry, tokenizer_fixture, min_length, max_length) - check_data_source_format(data_entry, data_domain, dataname) - - extra_instruction = EXTRA_INST_MAP.get(data_domain, None) - check_prompt_format(data_entry, extra_instruction) - - check_special_tokens(data_entry, MODEL_NAME_OR_PATH) - - logger.info(f"Successfully tested {len(dataset)} samples from {dataname}") \ No newline at end of file diff --git a/tests/e2e/arithmetic_sequence/data/create_dataset.py b/tests/e2e/arithmetic_sequence/data/create_dataset.py deleted file mode 100644 index 1729fd6af..000000000 --- a/tests/e2e/arithmetic_sequence/data/create_dataset.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from torch.utils import data - -from tests.e2e.envs.digit_completion import DigitCompletion - -if __name__ == "__main__": - simple_task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9) - all_prompts = simple_task.get_all_prompts() - - # 21 * 6 * 4 - train_data, test_data = data.random_split(all_prompts, lengths=[0.8, 0.2]) - train_data = list(train_data) - test_data = list(test_data) - - train_data = [[{"role": "user", "content": str(item)}] for item in train_data] - test_data = [[{"role": "user", "content": str(item)}] for item in test_data] - - print(f"Size of train: {len(train_data)}, size of test: {len(test_data)}") - - train_data = {"prompt": train_data} - test_data = {"prompt": test_data} - - model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__))) - - import pandas as pd - - train_data_frame = pd.DataFrame(train_data) - test_data_frame = pd.DataFrame(test_data) - - train_data_frame.to_parquet(os.path.join(model_folder, "train.parquet")) - test_data_frame.to_parquet(os.path.join(model_folder, "test.parquet")) diff --git a/tests/e2e/arithmetic_sequence/data/test.parquet b/tests/e2e/arithmetic_sequence/data/test.parquet deleted file mode 100644 index d0729dc3d..000000000 Binary files a/tests/e2e/arithmetic_sequence/data/test.parquet and /dev/null differ diff --git a/tests/e2e/arithmetic_sequence/data/train.parquet b/tests/e2e/arithmetic_sequence/data/train.parquet deleted file mode 100644 index 0a03a61a8..000000000 Binary files a/tests/e2e/arithmetic_sequence/data/train.parquet and /dev/null differ diff --git a/tests/e2e/arithmetic_sequence/model/config.json b/tests/e2e/arithmetic_sequence/model/config.json deleted file mode 100644 index 87944c51f..000000000 --- a/tests/e2e/arithmetic_sequence/model/config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": null, - "eos_token_id": 1, - "hidden_act": "silu", - "hidden_size": 128, - "initializer_range": 0.02, - "intermediate_size": 344, - "max_position_embeddings": 2048, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 4, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pad_token_id": 2, - "pretraining_tp": 1, - "rms_norm_eps": 1e-06, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.43.3", - "use_cache": true, - "vocab_size": 16 -} diff --git a/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py b/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py deleted file mode 100644 index bfab8538d..000000000 --- a/tests/e2e/arithmetic_sequence/model/create_model_tokenizer.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Create a random model and tokenizer for PPO training -""" - -import os - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig - -from tests.e2e.envs.digit_completion import CharTokenizer - -tokenizer = CharTokenizer( - characters=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ",", ":"], - model_max_length=2048, - chat_template="{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}", # noqa: E501 -) - -config = LlamaConfig( - vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16, - hidden_size=128, - intermediate_size=344, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=4, - pad_token_id=tokenizer.pad_token_id, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, -) - -model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) - -model_folder = os.path.join(os.path.dirname(os.path.abspath(__file__))) -os.makedirs(model_folder, exist_ok=True) - -model.save_pretrained(model_folder) - -tokenizer_folder = model_folder -tokenizer.save_pretrained(tokenizer_folder) - -load_tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder) - -chat = [{"role": "user", "content": "1,0:2,3"}] - -load_tokenizer.padding_side = "left" -print(load_tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, max_length=10, padding="max_length")) diff --git a/tests/e2e/arithmetic_sequence/model/generation_config.json b/tests/e2e/arithmetic_sequence/model/generation_config.json deleted file mode 100644 index 578d37505..000000000 --- a/tests/e2e/arithmetic_sequence/model/generation_config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "_from_model_config": true, - "eos_token_id": 1, - "pad_token_id": 2, - "transformers_version": "4.43.3" -} diff --git a/tests/e2e/arithmetic_sequence/model/model.safetensors b/tests/e2e/arithmetic_sequence/model/model.safetensors deleted file mode 100644 index 509e6e97c..000000000 Binary files a/tests/e2e/arithmetic_sequence/model/model.safetensors and /dev/null differ diff --git a/tests/e2e/arithmetic_sequence/model/tokenizer_config.json b/tests/e2e/arithmetic_sequence/model/tokenizer_config.json deleted file mode 100644 index d01bf75f1..000000000 --- a/tests/e2e/arithmetic_sequence/model/tokenizer_config.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "char_ords": [ - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 55, - 56, - 57, - 44, - 58 - ], - "model_max_length": 2048, - "chat_template": "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}" -} \ No newline at end of file diff --git a/tests/e2e/arithmetic_sequence/rl/README.md b/tests/e2e/arithmetic_sequence/rl/README.md deleted file mode 100644 index 9c561935f..000000000 --- a/tests/e2e/arithmetic_sequence/rl/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Digit completion - -This is an example of solving a digit completion problem. The problem is defined as below: - -The prompt is a sequence of numbers with fixed difference. The agent's goal is to complete the next N numbers. -If the max number is reached, the next number should be modulo with max number. - -For example, -- prompt = [1, 2, 3] -- N = 5 -- max_number = 6 - -The response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]. - -# Environment definition - -The core definition of the task is defined in tests/e2e/envs/digit_completion/task.py - -It is highly recommended to take a look at it for better understanding. - - - -# Run experiments - -An example of running the task is provided in `tests/e2e/run_ray_trainer.sh`. - -```bash -bash tests/e2e/run_ray_trainer.sh -``` - diff --git a/tests/e2e/arithmetic_sequence/rl/main_trainer.py b/tests/e2e/arithmetic_sequence/rl/main_trainer.py deleted file mode 100644 index 311903965..000000000 --- a/tests/e2e/arithmetic_sequence/rl/main_trainer.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Using FSDPTrainer -""" - -import os - -import hydra -import ray -import torch -from transformers import AutoTokenizer - -from verl import DataProto -from verl.trainer.ppo.ray_trainer import RayPPOTrainer -from verl.utils.fs import copy_to_local - - -def make_reward_function(tokenizer, num_examine): - def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False): - from tests.e2e.envs.digit_completion.task import compute_reward - - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - - for i in range(data.batch.batch_size[0]): - data_item = data[i] # DataProtoItem - - prompt_ids = data_item.batch["prompts"] - - prompt_length = prompt_ids.shape[-1] - - # extract raw prompt - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() - valid_prompt_ids = prompt_ids[-valid_prompt_length:] - - # extract response - response_ids = data_item.batch["responses"] - response_length = response_ids.shape[-1] - response_mask = data.batch["attention_mask"][i][-response_length:] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - prompt = tokenizer.decode(valid_prompt_ids) - response = tokenizer.decode(valid_response_ids) - # remove bos and eos - prompt = prompt.replace(tokenizer.sep_token, "") - response = response.replace(tokenizer.eos_token, "") - if i < num_examine: - print(prompt, response) - - reward_output = compute_reward(prompt, response) - dense_reward = reward_output[0].tolist() - ground_truth_response = reward_output[1]["ground_truth_response"] - last_reward = dense_reward[-1] if len(dense_reward) > 0 else 1 if len(ground_truth_response) == 0 else 0 - - # pad to response_length - for _ in range(reward_tensor.shape[-1] - len(dense_reward)): - dense_reward.append(last_reward) - - dense_reward = torch.as_tensor(dense_reward, dtype=torch.float32, device=reward_tensor.device) - reward_tensor[i] = dense_reward * response_mask - - if return_dict: - return {"reward_tensor": reward_tensor} - else: - return reward_tensor - - return arithmetic_sequence_reward_function - - -@hydra.main(config_path="../../../../verl/trainer/config", config_name="ppo_trainer", version_base=None) -def main(config): - ray.init( - runtime_env={ - "env_vars": { - "MEGATRON_USE_CUDA_TIMER": "0", - "MEGATRON_START_PROCESS_TIMER": "False", - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - } - }, - num_cpus=config.ray_init.num_cpus, - ) - - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - - # print the config - # print initial config - print("Config after normalizing batch_size") - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - local_path = os.path.expanduser(local_path) - # instantiate tokenizern - from transformers import LlamaConfig - - from tests.e2e.envs.digit_completion import CharTokenizer - - AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True) - tokenizer = AutoTokenizer.from_pretrained(local_path) - print(f"Tokenizer vocab_size: {tokenizer.vocab_size}") - - # define worker classes - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.Critic: ray.remote(CriticWorker), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - } - - # use reward model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - reward_fn = make_reward_function(tokenizer=tokenizer, num_examine=1) - - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RayPPOTrainer( - config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - reward_fn=reward_fn, - val_reward_fn=reward_fn, - ) - trainer.init_workers() - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/tests/e2e/naive_chat_scheduler.py b/tests/e2e/naive_chat_scheduler.py deleted file mode 100644 index 87592c766..000000000 --- a/tests/e2e/naive_chat_scheduler.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from typing import Any, Dict, List - -import torch -from openai.types.chat.chat_completion import ChatCompletion -from tensordict import TensorDict - -from verl.protocol import DataProto -from verl.workers.rollout.async_server import ChatCompletionScheduler - - -class NaiveChatCompletionScheduler(ChatCompletionScheduler): - """ - A very naive implementation of ChatCompletionScheduler for demo purpose, - only do single-turn chat completion. - """ - - async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: - kwargs = dict( - n=self.config.n, - max_completion_tokens=self.config.response_length, - temperature=self.config.temperature, - top_p=self.config.top_p, - ) - - do_sample = batch.meta_info.get("do_sample", True) - is_validate = batch.meta_info.get("validate", False) - if not do_sample or is_validate: - kwargs["n"] = 1 - kwargs["temperature"] = 0 - - kwargs.update(sampling_params) - print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - assert exception is None, f"exception: {exception}" - conversation, batch_conversations, batch_index = ( - info["conversation"], - info["batch_conversations"], - info["batch_index"], - ) - - conversations = [] - for choice in completions.choices: - chat = conversation.copy() - chat.append({"role": choice.message.role, "content": choice.message.content}) - conversations.append(chat) - batch_conversations[batch_index] = conversations - - # NOTE: we can call tools and resubmit chat completions here. - # call_tools(completions, info) - # await self.submit_chat_completions(callback2, ...) - - # TODO: we may need to control max concurrent requests here, or it will harm prefix cache hit rate. - tasks, batch_conversations = [], [None] * len(batch) - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): - # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] - tasks.append( - asyncio.create_task( - self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "conversation": list(conversation), - }, - model=self.model_name, - messages=conversation.tolist(), - **kwargs, - ) - ) - ) - await asyncio.gather(*tasks) - print("[NaiveChatCompletionScheduler] generate_sequences done") - - return self._postprocess(batch, batch_conversations, kwargs["n"]) - - def _postprocess(self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int) -> DataProto: - # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py - # prompts: left pad - # responses: right pad - # input_ids: prompt + response - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - - # prompts: [prompt] from input dataset - prompts = [self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]] - - # flatten batch_conversations if n > 1 - assert len(batch_conversations) == len(prompts) - batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations] - assert len(batch_conversations) == len(prompts) * n - - # sequences: [prompt + response] - sequences = [self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations] - - # responses: [response] - # TODO: mask out tools calling tokens? - responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] - - prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") - responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") - if n > 1: - prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) - prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) - - input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) - attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) - position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask - - batch = TensorDict( - { - "prompts": prompts["input_ids"], - "responses": responses["input_ids"], - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=len(input_ids), - ) - - return DataProto(batch=batch) diff --git a/tests/e2e/run_deepseek_megatron.sh b/tests/e2e/run_deepseek_megatron.sh deleted file mode 100644 index f86c96b30..000000000 --- a/tests/e2e/run_deepseek_megatron.sh +++ /dev/null @@ -1,40 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='deepseek_llm_1b3_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=3 $@ diff --git a/tests/e2e/run_deepseek_megatron_parallelism.sh b/tests/e2e/run_deepseek_megatron_parallelism.sh deleted file mode 100644 index 5398092d6..000000000 --- a/tests/e2e/run_deepseek_megatron_parallelism.sh +++ /dev/null @@ -1,46 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='deepseek_llm_1b3_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=3 $@ diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh deleted file mode 100644 index 91a3d5ac8..000000000 --- a/tests/e2e/run_ppo_trainer_megatron.sh +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -USE_DUMMY_MODEL=${USE_DUMMY_MODEL:-False} -DUMMY_MODEL_PATH=${DUMMY_MODEL_PATH:-${HOME}/dummy_models/${MODEL_ID}} -if [ "$USE_DUMMY_MODEL" = "True" ]; then - if [ -z "${DUMMY_MODEL_CONFIG_PATH}" ]; then - echo "[ERROR] DUMMY_MODEL_CONFIG_PATH not set" - exit 1 - fi - - python scripts/init_random_model.py \ - --hf_model_path "${MODEL_PATH}" \ - --new_config_path "${DUMMY_MODEL_CONFIG_PATH}" \ - --output_path "${DUMMY_MODEL_PATH}" - - MODEL_PATH="${DUMMY_MODEL_PATH}" -fi - -TRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet} - -ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} -# Validation -VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} -TEST_FREQ=${TEST_FREQ:--1} -# Save & Resume -RESUME_MODE=${RESUME_MODE:-disable} -SAVE_FREQ=${SAVE_FREQ:--1} -TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} - -USE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True} -ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN:-2400} -forward_max_token_len_per_gpu=${FWD_MAX_TOKEN_LEN:-4800} -train_traj_micro_bsz_per_gpu=${MICRO_BSZ:-2} # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512} -MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512} - -COMMON_PP=${COMMON_PP:-2} -COMMON_VPP=${COMMON_VPP:-2} -COMMON_CP=${COMMON_CP:-2} -COMMON_TP=${COMMON_TP:-2} -COMMON_EP=${COMMON_EP:-1} -COMMON_ETP=${COMMON_ETP:-null} - -TRAIN_TP=${TRAIN_TP:-$COMMON_TP} -INFER_TP=${INFER_TP:-$COMMON_TP} - -ACTOR_PP=${ACTOR_PP:-$COMMON_PP} -ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} -ACTOR_CP=${ACTOR_CP:-$COMMON_CP} -ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} -ACTOR_EP=${ACTOR_EP:-$COMMON_EP} -ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} -ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} -REF_PP=${REF_PP:-$COMMON_PP} -REF_VPP=${REF_VPP:-$COMMON_VPP} -REF_CP=${REF_CP:-$COMMON_CP} -REF_TP=${REF_TP:-$TRAIN_TP} -REF_EP=${REF_EP:-$COMMON_EP} -REF_ETP=${REF_ETP:-$COMMON_ETP} -CRITIC_PP=${CRITIC_PP:-$COMMON_PP} -CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} -CRITIC_CP=${CRITIC_CP:-$COMMON_CP} -CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} -CRITIC_EP=${CRITIC_EP:-$COMMON_EP} -CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} -RM_PP=${RM_PP:-$COMMON_PP} -RM_VPP=${RM_VPP:-$COMMON_VPP} -RM_CP=${RM_CP:-$COMMON_CP} -RM_TP=${RM_TP:-$TRAIN_TP} -RM_EP=${RM_EP:-$COMMON_EP} -RM_ETP=${RM_ETP:-$COMMON_ETP} - -ALL_OFFLOAD=${ALL_OFFLOAD:-False} -COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} -COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} -COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} - -ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} - -CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra'] -SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0} -if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then - CHECKPOINT_CONTENTS=['model','optimizer','extra'] -fi - -USE_DIST_CKPT=${USE_DIST_CKPT:-False} -DIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/${MODEL_ID}} -if [ "$USE_DIST_CKPT" = "True" ]; then - if [ "$USE_DUMMY_MODEL" = "True" ]; then - DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID} - fi - python scripts/converter_hf_to_mcore.py \ - --hf_model_path "${MODEL_PATH}" \ - --output_path "${DIST_CKPT_PATH}" -fi - -ENGINES=("vllm" "sglang_async") - -exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" - -for ENGINE in "${ENGINES[@]}"; do - python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator="${ADV_ESTIMATOR}" \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size=${train_prompt_bsz} \ - data.max_prompt_length=${MAX_PROMPT_LENGTH} \ - data.max_response_length=${MAX_RESPONSE_LENGTH} \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ - actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=$ACTOR_EP \ - actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ACTOR_ETP \ - actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \ - actor_rollout_ref.rollout.name="${ENGINE}" \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ - actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=$REF_EP \ - actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$REF_ETP \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - critic.optim.lr=2e-5 \ - critic.model.path="${MODEL_PATH}" \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \ - critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ - critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ - critic.megatron.context_parallel_size=$CRITIC_CP \ - critic.megatron.tensor_model_parallel_size=$CRITIC_TP \ - critic.megatron.expert_model_parallel_size=$CRITIC_EP \ - critic.megatron.expert_tensor_parallel_size=$CRITIC_ETP \ - critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \ - critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \ - critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ - critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - critic.checkpoint.contents=$CHECKPOINT_CONTENTS \ - reward_model.enable=True \ - reward_model.model.path="${MODEL_PATH}" \ - reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ - reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ - reward_model.megatron.context_parallel_size=$RM_CP \ - reward_model.megatron.tensor_model_parallel_size=$RM_TP \ - reward_model.megatron.expert_model_parallel_size=$RM_TP \ - reward_model.megatron.expert_tensor_parallel_size=$RM_TP \ - reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \ - reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - algorithm.use_kl_in_reward=False \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ - trainer.test_freq="${TEST_FREQ}" \ - trainer.save_freq="${SAVE_FREQ}" \ - trainer.resume_mode="${RESUME_MODE}" \ - trainer.total_epochs=2 \ - trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ -done diff --git a/tests/e2e/run_qwen_gsm8k_custom_function_rm.sh b/tests/e2e/run_qwen_gsm8k_custom_function_rm.sh deleted file mode 100644 index 72eb9dbb1..000000000 --- a/tests/e2e/run_qwen_gsm8k_custom_function_rm.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -set -e -x -FILE="$(pwd)/my_reward_function.py" -rm -rf $FILE -cat < "$FILE" -def my_reward_function(data_source, solution_str, ground_truth, extra_info=None): - print(f"Congratulations!!! You have called my_reward_function successfully!!!") - return 0.1 -EOF - - -OUTPUT_FILE="$(pwd)/output_custom_reward.txt" -FUNCTION_NAME="my_reward_function" -rm -rf $OUTPUT_FILE - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.kl_ctrl.kl_coef=0.001 \ - custom_reward_function.path=$FILE\ - custom_reward_function.name=$FUNCTION_NAME\ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen_e2e_ci_custom_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.default_local_dir=$HOME/ckpt/ \ - trainer.total_training_steps=2 | tee $OUTPUT_FILE; - -python3 tests/e2e/check_custom_rwd_fn.py --output_file=$OUTPUT_FILE -rm -rf $FILE -rm -rf $OUTPUT_FILE \ No newline at end of file diff --git a/tests/e2e/run_qwen_gsm8k_function_rm_grpo.sh b/tests/e2e/run_qwen_gsm8k_function_rm_grpo.sh deleted file mode 100644 index f9eb4aca8..000000000 --- a/tests/e2e/run_qwen_gsm8k_function_rm_grpo.sh +++ /dev/null @@ -1,33 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.adv_estimator=grpo \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen_e2e_ci_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh b/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh deleted file mode 100644 index cbe4fcec1..000000000 --- a/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh +++ /dev/null @@ -1,40 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=False \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - +trainer.val_before_train=False \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen_e2e_ci_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_function_rm_remax.sh b/tests/e2e/run_qwen_gsm8k_function_rm_remax.sh deleted file mode 100644 index 1eb5f7752..000000000 --- a/tests/e2e/run_qwen_gsm8k_function_rm_remax.sh +++ /dev/null @@ -1,33 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.adv_estimator=remax \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen_e2e_ci_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm.sh b/tests/e2e/run_qwen_gsm8k_model_rm.sh deleted file mode 100644 index 3e908d7a4..000000000 --- a/tests/e2e/run_qwen_gsm8k_model_rm.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=Qwen/Qwen2.5-0.5B\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=16 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - +trainer.val_before_train=False \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh b/tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh deleted file mode 100644 index ca6a5d108..000000000 --- a/tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - +actor_rollout_ref.model.use_liger=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=Qwen/Qwen2.5-0.5B\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - +trainer.val_before_train=False \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh b/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh deleted file mode 100644 index 97f270c52..000000000 --- a/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=False \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=Qwen/Qwen2.5-0.5B\ - reward_model.model.use_remove_padding=False \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=16 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - +trainer.val_before_train=False \ - trainer.logger=['console'] \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh b/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh deleted file mode 100644 index efe279fb5..000000000 --- a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh +++ /dev/null @@ -1,51 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=12000 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=12000 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.use_dynamic_bsz=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=Qwen/Qwen2.5-0.5B\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.use_dynamic_bsz=True \ - reward_model.forward_max_token_len_per_gpu=98304 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - +trainer.val_before_train=False \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm_seq_balance' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh b/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh deleted file mode 100644 index b4a18ac02..000000000 --- a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh +++ /dev/null @@ -1,53 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2 with flash_attn has some issues - -python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.ulysses_sequence_parallel_size=2 \ - critic.model.use_remove_padding=True \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - critic.model.fsdp_config.fsdp_size=4 \ - reward_model.enable=True \ - reward_model.ulysses_sequence_parallel_size=2 \ - reward_model.model.path=Qwen/Qwen2.5-0.5B\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=16 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - +trainer.val_before_train=False \ - trainer.logger=['console'] \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm_sp2' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_prime.sh b/tests/e2e/run_qwen_gsm8k_prime.sh deleted file mode 100644 index 31251d8dc..000000000 --- a/tests/e2e/run_qwen_gsm8k_prime.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m recipe.prime.main_prime \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=32 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_accuracy=True \ - data.accuracy_lower_bound=0.2 \ - data.accuracy_upper_bound=0.8 \ - data.oversample_factor=4 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=False \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.adv_estimator=rloo \ - reward_model.model.path=Qwen/Qwen2.5-0.5B \ - reward_model.micro_batch_size_per_gpu=1 \ - reward_model.model.update=before \ - reward_model.model.beta_train=0.05 \ - reward_model.model.optim.lr=1e-6 \ - reward_model.model.optim.grad_clip=10.0 \ - reward_model.model.input_tokenizer=null \ - reward_model.mini_batch_size=32 \ - reward_model.reward_manager=prime \ - trainer.val_before_train=False \ - trainer.logger=['console'] \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-0.5B-PRIME' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_megatron.sh b/tests/e2e/run_qwen_megatron.sh deleted file mode 100644 index 1da46ed14..000000000 --- a/tests/e2e/run_qwen_megatron.sh +++ /dev/null @@ -1,42 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download Qwen/Qwen2.5-0.5B - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='qwen2_5_0b5_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=3 $@ diff --git a/tests/e2e/run_qwen_megatron_parallelism.sh b/tests/e2e/run_qwen_megatron_parallelism.sh deleted file mode 100644 index 2c4996acc..000000000 --- a/tests/e2e/run_qwen_megatron_parallelism.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download Qwen/Qwen2.5-0.5B - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='qwen2_5_0b5_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=3 $@ diff --git a/tests/e2e/run_ray_trainer.sh b/tests/e2e/run_ray_trainer.sh deleted file mode 100644 index f9cb19aeb..000000000 --- a/tests/e2e/run_ray_trainer.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash - -set -e -x - -OUTPUT_FILE="/tmp/output_ray_trainer.txt" - -export PATH=$PATH:~/.local/bin - -rm -rf $OUTPUT_FILE -python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ - algorithm.adv_estimator=gae \ - data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ - data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ - data.train_batch_size=800 \ - data.max_prompt_length=16 \ - data.max_response_length=32 \ - data.return_raw_input_ids=True \ - actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ - actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=128 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.lr=1e-4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \ - actor_rollout_ref.rollout.name=hf \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - critic.ppo_micro_batch_size_per_gpu=128 \ - critic.model.path=tests/e2e/arithmetic_sequence/model \ - critic.optim.lr=1e-3 \ - algorithm.use_kl_in_reward=False \ - trainer.total_epochs=200 \ - trainer.experiment_name=arithmetic_sequences \ - trainer.logger=['console'] \ - trainer.n_gpus_per_node=1 \ - trainer.test_freq=1 \ - trainer.save_freq=110 | tee $OUTPUT_FILE; - -python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE -rm -rf $OUTPUT_FILE diff --git a/tests/e2e/run_ray_trainer_fire_sampling.sh b/tests/e2e/run_ray_trainer_fire_sampling.sh deleted file mode 100644 index 2b3f698c0..000000000 --- a/tests/e2e/run_ray_trainer_fire_sampling.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env bash - -set -e -x - -OUTPUT_FILE="/tmp/output_ray_trainer.txt" - -export PATH=$PATH:~/.local/bin - -rm -rf $OUTPUT_FILE -python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ - algorithm.adv_estimator=gae \ - data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ - data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ - data.train_batch_size=800 \ - data.val_batch_size=200 \ - data.max_prompt_length=16 \ - data.max_response_length=32 \ - data.return_raw_input_ids=True \ - actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ - actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=128 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.lr=1e-4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \ - actor_rollout_ref.rollout.name=hf \ - actor_rollout_ref.rollout.use_fire_sampling=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - critic.ppo_micro_batch_size_per_gpu=128 \ - critic.model.path=tests/e2e/arithmetic_sequence/model \ - critic.optim.lr=1e-3 \ - algorithm.use_kl_in_reward=False \ - trainer.total_epochs=200 \ - trainer.experiment_name=arithmetic_sequences \ - trainer.logger=['console'] \ - trainer.n_gpus_per_node=1 \ - trainer.test_freq=1 \ - trainer.save_freq=110 | tee $OUTPUT_FILE; - -python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE --target 0.19 -rm -rf $OUTPUT_FILE diff --git a/tests/e2e/run_ray_trainer_rmpad.sh b/tests/e2e/run_ray_trainer_rmpad.sh deleted file mode 100644 index edab167e6..000000000 --- a/tests/e2e/run_ray_trainer_rmpad.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env bash - -set -e -x - -huggingface-cli download Qwen/Qwen2.5-0.5B --local-dir $HOME/models/Qwen/Qwen2.5-0.5B - -python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ - algorithm.adv_estimator=gae \ - data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ - data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.model.tokenizer_path=tests/e2e/arithmetic_sequence/model \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.use_remove_padding=True \ - algorithm.use_kl_in_reward=False \ - trainer.total_epochs=1 diff --git a/tests/experimental/agent_loop/agent_utils.py b/tests/experimental/agent_loop/agent_utils.py new file mode 100644 index 000000000..3c708c42c --- /dev/null +++ b/tests/experimental/agent_loop/agent_utils.py @@ -0,0 +1,69 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ray +from omegaconf import DictConfig + +from verl.experimental.agent_loop import AgentLoopManager +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role +from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + +def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup: + # =========================== 1. Create hybrid ActorRollout workers =========================== + actor_rollout_cls = ( + AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ) + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + } + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager.create_resource_pool() + resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + ) + resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + + all_wg = {} + for resource_pool, class_dict in resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + actor_rollout_wg = all_wg["actor_rollout"] + actor_rollout_wg.init_model() + + if config.actor_rollout_ref.rollout.mode == "sync": + return actor_rollout_wg + + # =========================== 2. Create AgentLoopManager =========================== + agent_loop_manager = AgentLoopManager( + config=config, + worker_group=actor_rollout_wg, + ) + + return agent_loop_manager diff --git a/tests/experimental/agent_loop/test_basic_agent_loop.py b/tests/experimental/agent_loop/test_basic_agent_loop.py new file mode 100644 index 000000000..14deb01f0 --- /dev/null +++ b/tests/experimental/agent_loop/test_basic_agent_loop.py @@ -0,0 +1,286 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from typing import Any + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig +from transformers.utils import get_json_schema + +from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.experimental.agent_loop.agent_loop import get_trajectory_info +from verl.protocol import DataProto +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +def test_single_turn(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + agent_loop_manager = init_agent_loop_manager(init_config) + + raw_prompts = [ + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", + } + ], + [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), + }, + ) + n = init_config.actor_rollout_ref.rollout.n + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # check result + seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) + assert result.batch["input_ids"].size(1) == seq_len + assert result.batch["attention_mask"].size(1) == seq_len + assert result.batch["position_ids"].size(1) == seq_len + + # check turns + num_turns = result.non_tensor_batch["__num_turns__"] + assert np.all(num_turns == 2) + + print("Test passed!") + ray.shutdown() + + +class WeatherTool(BaseTool): + def get_current_temperature(self, location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + print(f"[DEBUG] get_current_temperature: {location}, {unit}") + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_current_temperature) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_current_temperature(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +class WeatherToolWithData(BaseTool): + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_temperature_date) + return OpenAIFunctionToolSchema(**schema) + + def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_temperature_date(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +def test_tool_agent(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool", + "config": {"type": "native"}, + }, + { + "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 + agent_loop_manager = init_agent_loop_manager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "user", "content": "What's the temperature in New York now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["tool_agent"] * len(raw_prompts)), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + if i // n == 0: + # [user, assistant] + assert num_turns[i] == 2 + else: + # [user, assistant, tool, assistant] + assert num_turns[i] == 4 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + response_length = response_mask.size(1) + + for i in range(len(responses)): + # response with tool response + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + print("=========================") + print(response_with_obs) + print("---") + print(response_without_obs) + + print("Test passed!") + ray.shutdown() + + +@pytest.mark.asyncio +async def test_get_trajectory_info(): + """Tests the get_trajectory_info method.""" + # Initialize the class to set up class-level attributes + step = 10 + index = [1, 1, 3, 3] + expected_info = [ + {"step": step, "sample_index": 1, "rollout_n": 0, "validate": False}, + {"step": step, "sample_index": 1, "rollout_n": 1, "validate": False}, + {"step": step, "sample_index": 3, "rollout_n": 0, "validate": False}, + {"step": step, "sample_index": 3, "rollout_n": 1, "validate": False}, + ] + + trajectory_info = await get_trajectory_info(step, index, validate=False) + + assert trajectory_info == expected_info diff --git a/tests/generation/run_gen_qwen05.sh b/tests/generation/run_gen_qwen05.sh deleted file mode 100755 index deea53d0c..000000000 --- a/tests/generation/run_gen_qwen05.sh +++ /dev/null @@ -1,31 +0,0 @@ -# Tested with 1 & 4 GPUs -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_gen_qwen05.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 -infer_tp=${3:-2} # Default tensor parallel size to 2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -python3 -m verl.trainer.main_generation \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=$nproc_per_node \ - data.path=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=prompt \ - data.n_samples=1 \ - data.output_path=$save_path \ - model.path=Qwen/Qwen2.5-0.5B-Instruct \ - +model.trust_remote_code=True \ - rollout.temperature=1.0 \ - rollout.top_k=50 \ - rollout.top_p=0.7 \ - rollout.prompt_length=2048 \ - rollout.response_length=1024 \ - rollout.tensor_model_parallel_size=$infer_tp \ - rollout.gpu_memory_utilization=0.8 diff --git a/tests/gpu_utility/test_ops.py b/tests/gpu_utility/test_ops.py deleted file mode 100644 index 4bfb22298..000000000 --- a/tests/gpu_utility/test_ops.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def test_flash_attn_cross_entropy(): - import torch - from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - from torch import nn - - from verl.utils.debug import log_gpu_memory_usage - from verl.utils.torch_functional import logprobs_from_logits_naive - - log_gpu_memory_usage("At start") - - hidden_states = torch.randn(size=(2048, 5120), device="cuda", requires_grad=True, dtype=torch.bfloat16) - - linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device="cuda", dtype=torch.bfloat16) - - logits = linear(hidden_states) - - # logits = logits.float() - labels = torch.randint(low=0, high=155136, size=(2048,), device="cuda") - - log_gpu_memory_usage("before computation") - # output = checkpoint.checkpoint(logprobs_from_logits, logits, labels, use_reentrant=True) - output = -cross_entropy_loss(logits, labels)[0] - # output = logprobs_from_logits(logits, labels) - log_gpu_memory_usage("After forward") - output.sum().backward() - log_gpu_memory_usage("After backward") - - groundtruth = logprobs_from_logits_naive(logits.float(), labels) - - torch.testing.assert_close(output, groundtruth) diff --git a/tests/gpu_utility/test_torch_functional.py b/tests/gpu_utility/test_torch_functional.py deleted file mode 100644 index ebe488ed0..000000000 --- a/tests/gpu_utility/test_torch_functional.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -from flash_attn.bert_padding import unpad_input - -from verl.utils.model import create_random_mask - - -def test_log_probs_from_logits_response_rmpad(): - from verl.utils.torch_functional import log_probs_from_logits_response, log_probs_from_logits_response_rmpad - - vocab_size = 32000 - batch_size = 2 - prompt_length = 256 - response_length = 256 - - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device="cuda") - attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.2, max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.6) - - response_mask = attention_mask[:, -response_length:] - - assert torch.all(response_mask[:, 0] == 1) - - logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device="cuda") - logits_rmpad = unpad_input(logits, attention_mask)[0] - - expected_output = log_probs_from_logits_response(input_ids=input_ids, logits=logits, response_length=response_length) - actual_output = log_probs_from_logits_response_rmpad(input_ids=input_ids, attention_mask=attention_mask, logits_rmpad=logits_rmpad, response_length=response_length) - - # This should bitwise align as only this operation only contains gather operators - assert torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask)) - - -@pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]) -def test_logprobs_from_logits_v2(dtype): - from verl.utils.torch_functional import logprobs_from_logits_naive, logprobs_from_logits_v2 - - vocab_size = 32000 - batch_size = 2 - seq_len = 512 - - labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") - logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) - - expected_output = logprobs_from_logits_naive(labels=labels, logits=logits) - actual_output = logprobs_from_logits_v2(labels=labels, logits=logits) - - if dtype in [torch.float16, torch.bfloat16]: # float16 falls back to an exactly equivalent method - assert torch.equal(actual_output, expected_output) - else: # small numerical difference when using gather / logsumexp approach - torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) - - -def test_lr_scheduler(): - from torch import nn - - model = nn.Linear(10, 10) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - from verl.utils.torch_functional import get_constant_schedule_with_warmup - - constant_lr = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2) - - lr_lst = [] - - for _ in range(5): - lr_lst.append(constant_lr.get_last_lr()[0]) - constant_lr.step() - - torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.001, 0.001]) - - from verl.utils.torch_functional import get_cosine_schedule_with_warmup - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - cosine_lr = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2, num_training_steps=5, min_lr_ratio=0.1) - - lr_lst = [] - - for _ in range(5): - lr_lst.append(cosine_lr.get_last_lr()[0]) - cosine_lr.step() - - torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.0007750000000000002, 0.0003250000000000002]) diff --git a/tests/input.txt b/tests/input.txt deleted file mode 100644 index fe663e803..000000000 --- a/tests/input.txt +++ /dev/null @@ -1,15 +0,0 @@ - -3 -4 -tilak + -tilak + -tilak - -tilak + -3 -ratna + -shashi - -ratna - -3 -bhavani - -bhavani + -bhavani - diff --git a/tests/interactions/__init__.py b/tests/interactions/__init__.py new file mode 100644 index 000000000..b6db0fcef --- /dev/null +++ b/tests/interactions/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/interactions/test_gsm8k_interaction.py b/tests/interactions/test_gsm8k_interaction.py new file mode 100644 index 000000000..bc16877c2 --- /dev/null +++ b/tests/interactions/test_gsm8k_interaction.py @@ -0,0 +1,421 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from verl.interactions.gsm8k_interaction import Gsm8kInteraction + + +class TestGsm8kInteraction: + """Test cases for Gsm8kInteraction class.""" + + def setup_method(self): + """Set up test environment before each test method.""" + self.config = {"name": "gsm8k"} + self.interaction = Gsm8kInteraction(self.config) + + def test_init(self): + """Test Gsm8kInteraction initialization.""" + assert self.interaction._instance_dict == {} + assert self.interaction.config == self.config + assert self.interaction.name == "gsm8k" + + @pytest.mark.asyncio + async def test_start_interaction_with_instance_id(self): + """Test start_interaction with provided instance_id.""" + instance_id = "test_instance" + ground_truth = "42" + + result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + assert result_id == instance_id + assert instance_id in self.interaction._instance_dict + assert self.interaction._instance_dict[instance_id]["response"] == "" + assert self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth + assert self.interaction._instance_dict[instance_id]["reward"] == 0.0 + + @pytest.mark.asyncio + async def test_start_interaction_without_instance_id(self): + """Test start_interaction without provided instance_id (auto-generated).""" + ground_truth = "42" + + result_id = await self.interaction.start_interaction(ground_truth=ground_truth) + + assert result_id is not None + assert len(result_id) == 36 # UUID4 length + assert result_id in self.interaction._instance_dict + assert self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth + + @pytest.mark.asyncio + async def test_start_interaction_without_ground_truth(self): + """Test start_interaction without ground_truth parameter.""" + instance_id = "test_instance" + + result_id = await self.interaction.start_interaction(instance_id=instance_id) + + assert result_id == instance_id + assert self.interaction._instance_dict[instance_id]["ground_truth"] is None + + @pytest.mark.asyncio + async def test_generate_response_correct_answer_with_prefix(self): + """Test generate_response with correct answer already having #### prefix.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [{"role": "user", "content": "#### 42"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is True + assert response == "Your response is correct!" + assert reward == 1.0 + assert metadata == {} + assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" + + @pytest.mark.asyncio + async def test_generate_response_correct_answer_without_prefix(self): + """Test generate_response with correct answer missing #### prefix.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [{"role": "user", "content": "42"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is True + assert response == "Your response is correct!" + assert reward == 1.0 + assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" + + @pytest.mark.asyncio + async def test_generate_response_incorrect_answer(self): + """Test generate_response with incorrect answer.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [{"role": "user", "content": "24"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is False + assert response == "Your response is incorrect! You need to reflect on your answer and try again." + assert reward == 0.0 + assert self.interaction._instance_dict[instance_id]["response"] == "#### 24" + + @pytest.mark.asyncio + async def test_generate_response_multiple_messages(self): + """Test generate_response with multiple messages (should use last user message).""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "Let me think about this..."}, + {"role": "user", "content": "#### 42"}, + ] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is True + assert response == "Your response is correct!" + assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" + + @pytest.mark.asyncio + async def test_generate_response_no_user_message(self): + """Test generate_response with no user messages.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [{"role": "assistant", "content": "Hello!"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is False + assert self.interaction._instance_dict[instance_id]["response"] == "#### " + + @pytest.mark.asyncio + async def test_calculate_score_direct_call(self): + """Test calculate_score method directly.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + # Set a response + self.interaction._instance_dict[instance_id]["response"] = "#### 42" + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0) as mock_compute: + score = await self.interaction.calculate_score(instance_id) + + assert score == 1.0 + mock_compute.assert_called_once_with("#### 42", "42", method="flexible", format_score=0.0, score=1.0) + + @pytest.mark.asyncio + async def test_calculate_score_with_kwargs(self): + """Test calculate_score method with additional kwargs.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + # Set a response + self.interaction._instance_dict[instance_id]["response"] = "#### 24" + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0) as mock_compute: + score = await self.interaction.calculate_score(instance_id, extra_param="test") + + assert score == 0.0 + mock_compute.assert_called_once_with("#### 24", "42", method="flexible", format_score=0.0, score=1.0) + + @pytest.mark.asyncio + async def test_finalize_interaction(self): + """Test finalize_interaction method.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + assert instance_id in self.interaction._instance_dict + + await self.interaction.finalize_interaction(instance_id) + + assert instance_id not in self.interaction._instance_dict + + @pytest.mark.asyncio + async def test_finalize_interaction_with_kwargs(self): + """Test finalize_interaction method with additional kwargs.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + assert instance_id in self.interaction._instance_dict + + await self.interaction.finalize_interaction(instance_id, extra_param="test") + + assert instance_id not in self.interaction._instance_dict + + @pytest.mark.asyncio + async def test_finalize_nonexistent_interaction(self): + """Test finalize_interaction with non-existent instance_id.""" + instance_id = "nonexistent_instance" + + # This should raise KeyError + with pytest.raises(KeyError): + await self.interaction.finalize_interaction(instance_id) + + @pytest.mark.asyncio + async def test_full_interaction_workflow_correct(self): + """Test complete interaction workflow with correct answer.""" + ground_truth = "42" + + # Start interaction + instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) + + # Generate response with correct answer + messages = [{"role": "user", "content": "42"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is True + assert reward == 1.0 + + # Finalize interaction + await self.interaction.finalize_interaction(instance_id) + assert instance_id not in self.interaction._instance_dict + + @pytest.mark.asyncio + async def test_full_interaction_workflow_incorrect(self): + """Test complete interaction workflow with incorrect answer.""" + ground_truth = "42" + + # Start interaction + instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) + + # Generate response with incorrect answer + messages = [{"role": "user", "content": "24"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is False + assert reward == 0.0 + + # Continue with another attempt + messages.append({"role": "assistant", "content": response}) + messages.append({"role": "user", "content": "42"}) + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is True + assert reward == 1.0 + + # Finalize interaction + await self.interaction.finalize_interaction(instance_id) + assert instance_id not in self.interaction._instance_dict + + @pytest.mark.asyncio + async def test_multiple_concurrent_interactions(self): + """Test multiple concurrent interaction instances.""" + ground_truth_1 = "42" + ground_truth_2 = "24" + + # Start multiple interactions + instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1) + instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2) + + assert len(self.interaction._instance_dict) == 2 + assert instance_id_1 in self.interaction._instance_dict + assert instance_id_2 in self.interaction._instance_dict + + # Test responses for both instances + messages_1 = [{"role": "user", "content": "42"}] + messages_2 = [{"role": "user", "content": "24"}] + + with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]): + should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1) + should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2) + + assert should_terminate_1 is True + assert should_terminate_2 is True + assert reward_1 == 1.0 + assert reward_2 == 1.0 + + # Finalize both interactions + await self.interaction.finalize_interaction(instance_id_1) + await self.interaction.finalize_interaction(instance_id_2) + + assert len(self.interaction._instance_dict) == 0 + + @pytest.mark.asyncio + async def test_edge_case_empty_messages(self): + """Test edge case with empty messages list.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is False + assert reward == 0.0 + assert self.interaction._instance_dict[instance_id]["response"] == "#### " + + @pytest.mark.asyncio + async def test_edge_case_message_without_content(self): + """Test edge case with message without content field.""" + instance_id = "test_instance" + ground_truth = "42" + + # Setup instance + await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + + messages = [ + {"role": "user"} # Missing content field + ] + + with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): + should_terminate, response, reward, metadata = await self.interaction.generate_response( + instance_id, messages + ) + + assert should_terminate is False + assert reward == 0.0 + assert self.interaction._instance_dict[instance_id]["response"] == "#### None" + + def test_inheritance_from_base_interaction(self): + """Test that Gsm8kInteraction properly inherits from BaseInteraction.""" + from verl.interactions.base import BaseInteraction + + assert isinstance(self.interaction, BaseInteraction) + + # Test that all required methods are implemented + assert hasattr(self.interaction, "start_interaction") + assert hasattr(self.interaction, "generate_response") + assert hasattr(self.interaction, "calculate_score") + assert hasattr(self.interaction, "finalize_interaction") + + # Test that methods are callable + assert callable(self.interaction.start_interaction) + assert callable(self.interaction.generate_response) + assert callable(self.interaction.calculate_score) + assert callable(self.interaction.finalize_interaction) + + def test_name_attribute_initialization(self): + """Test name attribute initialization with different configs.""" + # Test with explicit name in config + config_with_name = {"name": "custom_gsm8k"} + interaction_with_name = Gsm8kInteraction(config_with_name) + assert interaction_with_name.name == "custom_gsm8k" + + # Test with default name when not provided in config + config_without_name = {} + interaction_without_name = Gsm8kInteraction(config_without_name) + assert interaction_without_name.name == "interaction_agent" # Default from BaseInteraction + + # Test that name is accessible as attribute + assert hasattr(self.interaction, "name") + assert self.interaction.name == "gsm8k" diff --git a/tests/interactions/test_interaction_registry.py b/tests/interactions/test_interaction_registry.py new file mode 100644 index 000000000..7fe193b52 --- /dev/null +++ b/tests/interactions/test_interaction_registry.py @@ -0,0 +1,206 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import pytest +from omegaconf import OmegaConf + +from verl.interactions.base import BaseInteraction +from verl.interactions.gsm8k_interaction import Gsm8kInteraction +from verl.interactions.utils.interaction_registry import ( + get_interaction_class, + initialize_interactions_from_config, +) + + +class TestInteractionRegistry: + def test_get_interaction_class(self): + """Test getting interaction class by name.""" + # Test getting base interaction class + base_cls = get_interaction_class("verl.interactions.base.BaseInteraction") + assert base_cls == BaseInteraction + + # Test getting gsm8k interaction class + gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction") + assert gsm8k_cls == Gsm8kInteraction + + def test_initialize_single_interaction_from_config(self): + """Test initializing single interaction from config.""" + # Create temporary config file + config_content = { + "interaction": [ + { + "name": "test_gsm8k", + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + } + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + interaction_map = initialize_interactions_from_config(temp_config_path) + + # Check that interaction was created + assert len(interaction_map) == 1 + assert "test_gsm8k" in interaction_map + assert isinstance(interaction_map["test_gsm8k"], Gsm8kInteraction) + assert interaction_map["test_gsm8k"].name == "test_gsm8k" + finally: + os.unlink(temp_config_path) + + def test_initialize_multiple_interactions_from_config(self): + """Test initializing multiple interactions from config.""" + config_content = { + "interaction": [ + { + "name": "gsm8k_solver", + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + }, + { + "name": "base_agent", + "class_name": "verl.interactions.base.BaseInteraction", + "config": {"custom_param": "test_value"}, + }, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + interaction_map = initialize_interactions_from_config(temp_config_path) + + # Check that both interactions were created + assert len(interaction_map) == 2 + assert "gsm8k_solver" in interaction_map + assert "base_agent" in interaction_map + + # Check types + assert isinstance(interaction_map["gsm8k_solver"], Gsm8kInteraction) + assert isinstance(interaction_map["base_agent"], BaseInteraction) + + # Check names were injected + assert interaction_map["gsm8k_solver"].name == "gsm8k_solver" + assert interaction_map["base_agent"].name == "base_agent" + + # Check custom config was passed + assert interaction_map["base_agent"].config.get("custom_param") == "test_value" + finally: + os.unlink(temp_config_path) + + def test_initialize_interaction_without_explicit_name(self): + """Test that interaction name is derived from class name when not specified.""" + config_content = { + "interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + interaction_map = initialize_interactions_from_config(temp_config_path) + + # Check that interaction name was derived from class name + assert len(interaction_map) == 1 + assert "gsm8k" in interaction_map # Should be "gsm8k" after removing "interaction" suffix + assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction) + assert interaction_map["gsm8k"].name == "gsm8k" + finally: + os.unlink(temp_config_path) + + def test_initialize_empty_config(self): + """Test initializing from empty config.""" + config_content = {"interaction": []} + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + interaction_map = initialize_interactions_from_config(temp_config_path) + assert len(interaction_map) == 0 + finally: + os.unlink(temp_config_path) + + def test_invalid_class_name(self): + """Test handling of invalid class name.""" + config_content = { + "interaction": [{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + with pytest.raises(ModuleNotFoundError): + initialize_interactions_from_config(temp_config_path) + finally: + os.unlink(temp_config_path) + + def test_duplicate_interaction_names(self): + """Test handling of duplicate interaction names.""" + config_content = { + "interaction": [ + {"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}}, + { + "name": "duplicate", + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + }, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"): + initialize_interactions_from_config(temp_config_path) + finally: + os.unlink(temp_config_path) + + def test_auto_name_generation_edge_cases(self): + """Test automatic name generation for various class name patterns.""" + config_content = { + "interaction": [ + {"class_name": "verl.interactions.base.BaseInteraction", "config": {}}, + {"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(config_content, f.name) + temp_config_path = f.name + + try: + interaction_map = initialize_interactions_from_config(temp_config_path) + + # Check that names were generated correctly + assert len(interaction_map) == 2 + assert "base" in interaction_map # BaseInteraction -> base + assert "gsm8k" in interaction_map # Gsm8kInteraction -> gsm8k + finally: + os.unlink(temp_config_path) diff --git a/tests/models/test_transformer.py b/tests/models/test_transformer.py index 467165d05..111230a8a 100644 --- a/tests/models/test_transformer.py +++ b/tests/models/test_transformer.py @@ -44,7 +44,9 @@ def test_hf_casual_models(): for config in test_configs: # config = AutoConfig.from_pretrained(test_case) with torch.device("cuda"): - model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) model = model.to(device="cuda") input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") attention_mask = create_random_mask( @@ -53,18 +55,28 @@ def test_hf_casual_models(): max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.5, ) - position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + position_ids = compute_position_id_with_mask( + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, vocab_size) - origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits + origin_logits = model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ).logits origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask) logits_rmpad = logits_rmpad.squeeze(0) @@ -104,7 +116,9 @@ def test_hf_value_models(): config.classifier_dropout = 0 config.hidden_dropout = 0 with torch.device("cuda"): - model = AutoModelForTokenClassification.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForTokenClassification.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) model = model.to(device="cuda") input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") attention_mask = create_random_mask( @@ -113,18 +127,28 @@ def test_hf_value_models(): max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.5, ) - position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + position_ids = compute_position_id_with_mask( + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) - origin_logits = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False).logits + origin_logits = model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ).logits # input with input_ids_rmpad and postition_ids to enable flash attention varlen - rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, 1) + rmpad_logits = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, 1) rmpad_logits = rmpad_logits.squeeze(0) pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen) diff --git a/tests/models/test_transformers_ulysses.py b/tests/models/test_transformers_ulysses.py index 3ebe3fa52..111b35ec9 100644 --- a/tests/models/test_transformers_ulysses.py +++ b/tests/models/test_transformers_ulysses.py @@ -27,7 +27,7 @@ from verl.utils.distributed import initialize_global_process_group from verl.utils.model import compute_position_id_with_mask, create_random_mask from verl.utils.ulysses import ( - gather_outpus_and_unpad, + gather_outputs_and_unpad, get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group, ulysses_pad_and_slice_inputs, @@ -47,7 +47,9 @@ class SequenceParallelConfig: def test_configs(): return [ - SequenceParallelConfig(LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True), + SequenceParallelConfig( + LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True + ), SequenceParallelConfig( Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), sp_size=4, @@ -58,8 +60,12 @@ def test_configs(): sp_size=8, is_valid=False, ), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True), - SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True + ), + SequenceParallelConfig( + Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True + ), ] @@ -86,7 +92,9 @@ def test_hf_casual_fwd_bwd(test_config): def _hf_casual_fwd(config, sp_size, dp_size): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp") + ) sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) batch_size = 1 @@ -95,15 +103,21 @@ def _hf_casual_fwd(config, sp_size, dp_size): # patch before load with torch.device("cuda"): - model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) apply_monkey_patch(model, sp_size) model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8) - position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + ) + position_ids = compute_position_id_with_mask( + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here model_inputs = { "input_ids": input_ids.cuda(), @@ -119,25 +133,35 @@ def _hf_casual_fwd(config, sp_size, dp_size): input_ids = model_inputs.batch["input_ids"] attention_mask = model_inputs.batch["attention_mask"] position_ids = model_inputs.batch["position_ids"] - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False).logits # (1, total_nnz/n, vocab_size) + logits_split_in_seq = model( + input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + ).logits # (1, total_nnz/n, vocab_size) # all_gather output - logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) + logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) # 2. perform normal forward set_ulysses_sequence_parallel_group(None) - logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad_local = model( + input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False + ).logits # (1, total_nnz, vocab_size) mean_local = logits_rmpad_local.mean() mean_full = logits_full.mean() @@ -147,7 +171,9 @@ def _hf_casual_fwd(config, sp_size, dp_size): def _hf_casual_fwd_bwd(config, sp_size, dp_size): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp") + ) sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) batch_size = 1 @@ -156,15 +182,21 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): # patch before load with torch.device("cuda"): - model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) apply_monkey_patch(model, sp_size) model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8) - position_ids = compute_position_id_with_mask(attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + ) + position_ids = compute_position_id_with_mask( + attention_mask + ) # TODO(sgm): we can construct the position_ids_rmpad here model_inputs = { "input_ids": input_ids.cuda(), @@ -180,28 +212,38 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): input_ids = model_inputs.batch["input_ids"] attention_mask = model_inputs.batch["attention_mask"] position_ids = model_inputs.batch["position_ids"] - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False).logits # (1, total_nnz/n, vocab_size) + logits_split_in_seq = model( + input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + ).logits # (1, total_nnz/n, vocab_size) # all_gather output - logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) + logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) # 2. perform normal forward set_ulysses_sequence_parallel_group(None) input_ids_full = copy.deepcopy(input_ids_rmpad) position_ids_full = copy.deepcopy(position_ids_rmpad) model_no_sp = copy.deepcopy(model) - logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full, use_cache=False).logits # (1, total_nnz, vocab_size) + logits_rmpad_local = model_no_sp( + input_ids_full, position_ids=position_ids_full, use_cache=False + ).logits # (1, total_nnz, vocab_size) mean_local = logits_rmpad_local.mean() mean_full = logits_full.mean() diff --git a/tests/ray_cpu/test_check_worker_alive.py b/tests/ray_cpu/test_check_worker_alive.py deleted file mode 100644 index 1596fd3c9..000000000 --- a/tests/ray_cpu/test_check_worker_alive.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import subprocess -import time - - -def test(): - wait_time = 10 - - my_env = os.environ.copy() - my_env["WAIT_TIME"] = str(wait_time) - - p = subprocess.Popen(["python3", "-u", "./check_worker_alive/main.py"], env=my_env, stdout=subprocess.PIPE) - - count = 0 - while b"foo started" not in p.stdout.read(): - time.sleep(1) - count += 1 - if count > 40: - raise RuntimeError("timeout for start foo in check_worker_alive/main.py") - - print( - time.time(), - f"wait 1.5 wait time {wait_time * 1.5} to let signal returned to process but still not exceed process wait time", - ) - time.sleep(wait_time * 1.5) - print(time.time(), "start checking") - assert p.poll() is not None, f"process {p} still alive, expecting signal raised abort" - assert p.returncode != 0, f"process {p} exit with code 0, expecting not-zero exit code" - print("test passed") - - -if __name__ == "__main__": - test() diff --git a/tests/reward/test_codegen_reward.py b/tests/reward/test_codegen_reward.py deleted file mode 100644 index 89b99caca..000000000 --- a/tests/reward/test_codegen_reward.py +++ /dev/null @@ -1,33 +0,0 @@ -from verl.utils.reward_score.coder1 import compute_score - - -def test_tablereason_reward(): - output = "\\boxed{aboriginal population}" - ground_truth = "aboriginal population" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "To determine the most popular religion among all French-language workers in New Brunswick's agri-food sector, we need to compare the percentages of French-language workers belonging to each religious group across the different regions.\n\nHere are the percentages for French-language workers in each region:\n- Agricultural region 1: \n - Catholic: 94.2%\n - Pentecostal: 1.5%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 0.7%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 0.0%\n - Other religions: 0.0%\n - No religious affiliation: 3.3%\n\n- Agricultural region 3:\n - Catholic: 87.2%\n - Pentecostal: 0.5%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 1.7%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 1.1%\n - Other religions: 0.0%\n - No religious affiliation: 9.4%\n\n- Agricultural region 4:\n - Catholic: 96.4%\n - Pentecostal: 0.0%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 0.6%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 0.0%\n - Other religions: 0.0%\n - No religious affiliation: 2.3%\n\nFrom the data, we can see that the most common religion among French-language workers across all regions is Catholic, with percentages of 94.2%, 87.2%, and 96.4% in agricultural regions 1, 3, and 4, respectively.\n\nTherefore, the most popular religion among all French-language workers in New Brunswick's agri-food sector is \\boxed{catholic}." - ground_truth = "catholic" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = " Aboriginal population" - ground_truth = "aboriginal population" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "To solve this problem, we first need to identify the intramural R&D expenditures for the specified states from the table, which are New York, California, Florida, South Carolina, and Connecticut. Then, we will calculate the percentage each state contributed to the total intramural R&D expenditures of $582 million.\n\nFrom the table, we can extract the following intramural R&D expenditures:\n- New York: 241,921\n- California: 55,223\n- Florida: 45,016\n- South Carolina: 22,352\n- Connecticut: 21,886\n\nNext, we calculate the total intramural R&D expenditures for these five states:\n\\[ 241,921 + 55,223 + 45,016 + 22,352 + 21,886 = 386,398 \\]\n\nNow, we need to find the percentage of the total $582 million that these five states account for:\n\\[ \\text{Percentage} = \\left( \\frac{386,398}{582,000,000} \\right) \\times 100 \\approx 0.066435 \\times 100 \\approx 6.64\\% \\]\n\nThus, the five states account for approximately 6.64% of the $582 million of intramural R&D performed by state agencies.\n\nTherefore, the answer is:\n\\boxed{6.64%}" - ground_truth = "6.64%" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\\boxed{threatening or harassing phone calls | forcible confinement, kidnapping or abduction|criminal harassment}" - ground_truth = "forcible confinement, kidnapping or abduction|criminal harassment|threatening or harassing phone calls" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\nOkay, let's try to figure out how many appearances McIndoe made for Doncaster Rovers. First, I need to look at the table provided and find the relevant information.\n\nLooking at the table, there's a row for Doncaster Rovers. The seasons listed are 2003-04, 2004-05, 2005-06, and then a total. Under the \"league\" column, for each season, there are numbers under \"apps\" which I assume stands for appearances. \n\nFor the 2003-04 season, Doncaster Rovers have 45 apps in the third division. Then in 2004-05, they have 44 apps in League One. In 2005-06, 33 apps in League One again. The total for Doncaster Rovers is listed as 122 apps in the league. But the question is about appearances for Doncaster Rovers. However, the table also includes other competitions like national cup, league cup, and other. But the user is asking specifically about appearances for Doncaster Rovers. \n\nWait, but the question is about McIndoe's appearances. Wait, the table's columns are club, season, league, then national cup, league cup, other, and total. But the data here might be for a player's stats. But the problem is, the table's structure is a bit confusing. Let me check again.\n\nLooking at the headers: club, season, league, then three columns (division, apps, goals), then national cup, then league cup, then other, then total. Wait, the first row after the headers is \"luton town\" with 1998-99 season, second division, 22 apps, 0 goals, then national cup apps 2, goals 0, etc. So for each club and season, the data is broken down into league, national cup, league cup, other, and total. \n\nSo for Doncaster Rovers, the entries are:\n\n2003-04: third division, 45 apps, 10 goals; national cup apps 1, goals 0; league cup apps 2, goals 0; other apps 2, goals 0; total apps 50, goals 10.\n\n2004-05: league one, 44 apps, 10 goals; national cup 2 apps, 1 goal; league cup 3 apps, 1 goal; other 2 apps, 0 goals; total 51 apps, 12 goals.\n\n2005-06: league one, 33 apps, 8 goals; national cup 3 apps, 2 goals; league cup 5 apps, 3 goals; other 0 apps, 0 goals; total 41 apps, 13 goals.\n\nThen the total for Doncaster Rovers is 122 league apps, 6 national cup apps, 10 league cup apps, 4 other apps, and total 142 apps, 35 goals.\n\nBut the question is asking for the number of appearances McIndoe made for Doncaster Rovers. However, the table doesn't mention McIndoe's name anywhere. Wait, maybe the table is about a player's career? But the clubs listed are different, like Luton Town, Hereford United, etc. But the user is asking about McIndoe. Maybe the table is for McIndoe's career? But the problem is that the table doesn't have a player's name. Maybe the user is assuming that the table is about McIndoe? Or perhaps there's a mistake here.\n\nWait, the original question is \"how many appearances did mcindoe make for the doncaster rovers?\" But the table provided doesn't have a player's name. Maybe the table is for a player named McIndoe? But the table's first column is \"club\", so maybe each row is a club's stats, but that doesn't make sense. Alternatively, maybe the table is for a player's stats across different clubs. However, the way the table is structured, each row is a club-season, and the columns are for different competitions. For example, for Luton Town 1998-99, under league (second division), they have 22 apps, 0 goals. Then national cup apps 2, goals 0, etc. So this seems like a player's stats across different clubs. But the player's name isn't mentioned. Maybe the user is assuming that the table is for McIndoe, but the question is missing that info. \n\nAlternatively, maybe the user made a mistake in the question, but given the table, we have to answer based on the data provided. However, since the table doesn't mention McIndoe, but the question is about McIndoe's appearances for Doncaster Rovers, perhaps the answer is that there's no data available. But that seems unlikely. Maybe the table is for McIndoe's career, and the user is asking for that. \n\nAssuming that the table is for McIndoe's career, then for Doncaster Rovers, the total appearances would be the \"total\" column for Doncaster Rovers. Looking at the Doncaster Rovers row, the total is 142 apps. But the question is specifically about Doncaster Rovers. However, the total for Doncaster Rovers is 142 apps across all competitions (league, national cup, league cup, other). But if the question is asking for total appearances for Doncaster Rovers, then it's 142. However, if the question is asking for league appearances only, it would be 122. But the question says \"appearances\" without specifying, so it's ambiguous. However, in football stats, \"appearances\" usually refers to all competitions unless specified otherwise. \n\nBut let me check the table again. For Doncaster Rovers, the total is 142 apps. But looking at the individual seasons:\n\n2003-04: total apps 50 (league 45, national cup 1, league cup 2, other 2)\n2004-05: total 51 (44 + 2 + 3 + 2)\n2005-06: total 41 (33 + 3 +5 +0)\nAdding those: 50 + 51 +41 = 142. Yes, that matches the total. \n\nSo if the question is asking for total appearances for Doncaster Rovers, the answer is 142. However, the problem is that the table doesn't mention McIndoe. But since the user is asking about McIndoe, perhaps the table is for McIndoe's career. Maybe the user assumes that the table is for McIndoe. But that's an assumption. However, given that the question is about McIndoe and the table is provided, perhaps that's the case. \n\nAlternatively, maybe the user made a mistake in the question, but given the information, the answer would be 142. However, let me check again. The user's question is \"how many appearances did mcindoe make for the doncaster rovers?\" and the table has a row for Doncaster Rovers with total apps 142. So if we assume that the table is about McIndoe's career, then the answer is 142. But since the table's first column is \"club\", maybe it's the club's stats, not a player's. But that would make the question impossible. \n\nAlternatively, maybe the user intended the table to be for a player, and there's a formatting issue. For example, maybe the \"club\" column is the player's name, but that doesn't make sense. \n\nAlternatively, perhaps the table is for a player's stats across different clubs, and the \"club\" column is the club he played for. So for example, the first row is Luton Town, and the stats are for that player's appearances there. If that's the case, then for Doncaster Rovers, the total appearances would be 142. So the answer is 142. \n\nGiven that, I think the answer is 142. But I need to make sure. Let me check the \"career total\" row. The career total is 568 apps. For Doncaster Rovers, the total is 142. So if the question is about Doncaster Rovers, the answer is 142. \n\nBut the problem is that the table doesn't have the player's name. However, since the question is about McIndoe, and the table is provided, perhaps it's implied that the table is for McIndoe. Therefore, the answer is 142.\n\n\nThe question asks for the number of appearances McIndoe made for Doncaster Rovers. Based on the table provided, the total appearances for Doncaster Rovers across all competitions (league, national cup, league cup, and other) are listed as **142**. This figure is derived from the \"total\" column under the Doncaster Rovers row, which aggregates all appearances across the 2003-04, 2004-05, and 2005-06 seasons.\n\n$$\n\\boxed{142}\n$$", - ground_truth = "142.0" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 diff --git a/tests/reward/test_tablereason_reward.py b/tests/reward/test_tablereason_reward.py deleted file mode 100644 index 6c10a720e..000000000 --- a/tests/reward/test_tablereason_reward.py +++ /dev/null @@ -1,46 +0,0 @@ -from verl.utils.reward_score.tablereason import compute_score - - -def test_tablereason_reward(): - output = "\\boxed{aboriginal population}" - ground_truth = "aboriginal population" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "To determine the most popular religion among all French-language workers in New Brunswick's agri-food sector, we need to compare the percentages of French-language workers belonging to each religious group across the different regions.\n\nHere are the percentages for French-language workers in each region:\n- Agricultural region 1: \n - Catholic: 94.2%\n - Pentecostal: 1.5%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 0.7%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 0.0%\n - Other religions: 0.0%\n - No religious affiliation: 3.3%\n\n- Agricultural region 3:\n - Catholic: 87.2%\n - Pentecostal: 0.5%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 1.7%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 1.1%\n - Other religions: 0.0%\n - No religious affiliation: 9.4%\n\n- Agricultural region 4:\n - Catholic: 96.4%\n - Pentecostal: 0.0%\n - Anglican: 0.0%\n - Presbyterian: 0.0%\n - United Church: 0.0%\n - Other Christian: 0.6%\n - Buddhist, Hindu, Jewish, Muslim, and Sikh: 0.0%\n - Other religions: 0.0%\n - No religious affiliation: 2.3%\n\nFrom the data, we can see that the most common religion among French-language workers across all regions is Catholic, with percentages of 94.2%, 87.2%, and 96.4% in agricultural regions 1, 3, and 4, respectively.\n\nTherefore, the most popular religion among all French-language workers in New Brunswick's agri-food sector is \\boxed{catholic}." - ground_truth = "catholic" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = " Aboriginal population" - ground_truth = "aboriginal population" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "To solve this problem, we first need to identify the intramural R&D expenditures for the specified states from the table, which are New York, California, Florida, South Carolina, and Connecticut. Then, we will calculate the percentage each state contributed to the total intramural R&D expenditures of $582 million.\n\nFrom the table, we can extract the following intramural R&D expenditures:\n- New York: 241,921\n- California: 55,223\n- Florida: 45,016\n- South Carolina: 22,352\n- Connecticut: 21,886\n\nNext, we calculate the total intramural R&D expenditures for these five states:\n\\[ 241,921 + 55,223 + 45,016 + 22,352 + 21,886 = 386,398 \\]\n\nNow, we need to find the percentage of the total $582 million that these five states account for:\n\\[ \\text{Percentage} = \\left( \\frac{386,398}{582,000,000} \\right) \\times 100 \\approx 0.066435 \\times 100 \\approx 6.64\\% \\]\n\nThus, the five states account for approximately 6.64% of the $582 million of intramural R&D performed by state agencies.\n\nTherefore, the answer is:\n\\boxed{6.64%}" - ground_truth = "6.64%" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\\boxed{threatening or harassing phone calls | forcible confinement, kidnapping or abduction|criminal harassment}" - ground_truth = "forcible confinement, kidnapping or abduction|criminal harassment|threatening or harassing phone calls" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\nOkay, let's try to figure out how many appearances McIndoe made for Doncaster Rovers. First, I need to look at the table provided and find the relevant information.\n\nLooking at the table, there's a row for Doncaster Rovers. The seasons listed are 2003-04, 2004-05, 2005-06, and then a total. Under the \"league\" column, for each season, there are numbers under \"apps\" which I assume stands for appearances. \n\nFor the 2003-04 season, Doncaster Rovers have 45 apps in the third division. Then in 2004-05, they have 44 apps in League One. In 2005-06, 33 apps in League One again. The total for Doncaster Rovers is listed as 122 apps in the league. But the question is about appearances for Doncaster Rovers. However, the table also includes other competitions like national cup, league cup, and other. But the user is asking specifically about appearances for Doncaster Rovers. \n\nWait, but the question is about McIndoe's appearances. Wait, the table's columns are club, season, league, then national cup, league cup, other, and total. But the data here might be for a player's stats. But the problem is, the table's structure is a bit confusing. Let me check again.\n\nLooking at the headers: club, season, league, then three columns (division, apps, goals), then national cup, then league cup, then other, then total. Wait, the first row after the headers is \"luton town\" with 1998-99 season, second division, 22 apps, 0 goals, then national cup apps 2, goals 0, etc. So for each club and season, the data is broken down into league, national cup, league cup, other, and total. \n\nSo for Doncaster Rovers, the entries are:\n\n2003-04: third division, 45 apps, 10 goals; national cup apps 1, goals 0; league cup apps 2, goals 0; other apps 2, goals 0; total apps 50, goals 10.\n\n2004-05: league one, 44 apps, 10 goals; national cup 2 apps, 1 goal; league cup 3 apps, 1 goal; other 2 apps, 0 goals; total 51 apps, 12 goals.\n\n2005-06: league one, 33 apps, 8 goals; national cup 3 apps, 2 goals; league cup 5 apps, 3 goals; other 0 apps, 0 goals; total 41 apps, 13 goals.\n\nThen the total for Doncaster Rovers is 122 league apps, 6 national cup apps, 10 league cup apps, 4 other apps, and total 142 apps, 35 goals.\n\nBut the question is asking for the number of appearances McIndoe made for Doncaster Rovers. However, the table doesn't mention McIndoe's name anywhere. Wait, maybe the table is about a player's career? But the clubs listed are different, like Luton Town, Hereford United, etc. But the user is asking about McIndoe. Maybe the table is for McIndoe's career? But the problem is that the table doesn't have a player's name. Maybe the user is assuming that the table is about McIndoe? Or perhaps there's a mistake here.\n\nWait, the original question is \"how many appearances did mcindoe make for the doncaster rovers?\" But the table provided doesn't have a player's name. Maybe the table is for a player named McIndoe? But the table's first column is \"club\", so maybe each row is a club's stats, but that doesn't make sense. Alternatively, maybe the table is for a player's stats across different clubs. However, the way the table is structured, each row is a club-season, and the columns are for different competitions. For example, for Luton Town 1998-99, under league (second division), they have 22 apps, 0 goals. Then national cup apps 2, goals 0, etc. So this seems like a player's stats across different clubs. But the player's name isn't mentioned. Maybe the user is assuming that the table is for McIndoe, but the question is missing that info. \n\nAlternatively, maybe the user made a mistake in the question, but given the table, we have to answer based on the data provided. However, since the table doesn't mention McIndoe, but the question is about McIndoe's appearances for Doncaster Rovers, perhaps the answer is that there's no data available. But that seems unlikely. Maybe the table is for McIndoe's career, and the user is asking for that. \n\nAssuming that the table is for McIndoe's career, then for Doncaster Rovers, the total appearances would be the \"total\" column for Doncaster Rovers. Looking at the Doncaster Rovers row, the total is 142 apps. But the question is specifically about Doncaster Rovers. However, the total for Doncaster Rovers is 142 apps across all competitions (league, national cup, league cup, other). But if the question is asking for total appearances for Doncaster Rovers, then it's 142. However, if the question is asking for league appearances only, it would be 122. But the question says \"appearances\" without specifying, so it's ambiguous. However, in football stats, \"appearances\" usually refers to all competitions unless specified otherwise. \n\nBut let me check the table again. For Doncaster Rovers, the total is 142 apps. But looking at the individual seasons:\n\n2003-04: total apps 50 (league 45, national cup 1, league cup 2, other 2)\n2004-05: total 51 (44 + 2 + 3 + 2)\n2005-06: total 41 (33 + 3 +5 +0)\nAdding those: 50 + 51 +41 = 142. Yes, that matches the total. \n\nSo if the question is asking for total appearances for Doncaster Rovers, the answer is 142. However, the problem is that the table doesn't mention McIndoe. But since the user is asking about McIndoe, perhaps the table is for McIndoe's career. Maybe the user assumes that the table is for McIndoe. But that's an assumption. However, given that the question is about McIndoe and the table is provided, perhaps that's the case. \n\nAlternatively, maybe the user made a mistake in the question, but given the information, the answer would be 142. However, let me check again. The user's question is \"how many appearances did mcindoe make for the doncaster rovers?\" and the table has a row for Doncaster Rovers with total apps 142. So if we assume that the table is about McIndoe's career, then the answer is 142. But since the table's first column is \"club\", maybe it's the club's stats, not a player's. But that would make the question impossible. \n\nAlternatively, maybe the user intended the table to be for a player, and there's a formatting issue. For example, maybe the \"club\" column is the player's name, but that doesn't make sense. \n\nAlternatively, perhaps the table is for a player's stats across different clubs, and the \"club\" column is the club he played for. So for example, the first row is Luton Town, and the stats are for that player's appearances there. If that's the case, then for Doncaster Rovers, the total appearances would be 142. So the answer is 142. \n\nGiven that, I think the answer is 142. But I need to make sure. Let me check the \"career total\" row. The career total is 568 apps. For Doncaster Rovers, the total is 142. So if the question is about Doncaster Rovers, the answer is 142. \n\nBut the problem is that the table doesn't have the player's name. However, since the question is about McIndoe, and the table is provided, perhaps it's implied that the table is for McIndoe. Therefore, the answer is 142.\n\n\nThe question asks for the number of appearances McIndoe made for Doncaster Rovers. Based on the table provided, the total appearances for Doncaster Rovers across all competitions (league, national cup, league cup, and other) are listed as **142**. This figure is derived from the \"total\" column under the Doncaster Rovers row, which aggregates all appearances across the 2003-04, 2004-05, and 2005-06 seasons.\n\n$$\n\\boxed{142}\n$$", - ground_truth = "142.0" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\nOkay, let's try to figure out the answer to this question. The question is asking: \"Among which age group people engaging in utilitarian walking more than 5 hours a week was more prevalent?\" So, I need to look at the table provided and find which age group has the highest percentage of people who walk utilitarian for more than 5 hours a week.\n\nFirst, I need to understand the table structure. The table has several rows for different characteristics like sex, education, household income, age group, and daily physical activity. Each of these has columns for \"less than 1 hour,\" \"1 to 5 hours,\" and \"more than 5 hours,\" with percentages and 95% confidence intervals. There are also some numbers at the end of each row, probably chi-square statistics and p-values, but maybe that's not directly relevant here.\n\nThe question is about the age group, so I should focus on the \"age group\" section. Let me locate that part of the table. The rows under \"age group\" are \"20 to 39,\" \"40 to 59,\" and \"60 to 79.\" For each of these, there are percentages for the three categories: less than 1 hour, 1 to 5 hours, and more than 5 hours.\n\nThe key here is to look at the \"more than 5 hours\" column for each age group. The question is asking which age group had a higher prevalence (i.e., higher percentage) of people engaging in utilitarian walking more than 5 hours a week. \n\nLooking at the \"more than 5 hours\" column for each age group:\n\n- For \"20 to 39\": The percentage is 19.8% (with confidence interval 16.4 to 23.7).\n- For \"40 to 59\": The percentage is 16.9% (14.5 to 19.5).\n- For \"60 to 79\": The percentage is 15.9% (13.8 to 18.4).\n\nComparing these percentages: 19.8% (20-39) is higher than 16.9% (40-59) and 15.9% (60-79). Therefore, the age group with the highest prevalence of engaging in utilitarian walking more than 5 hours a week is \"20 to 39.\"\n\nWait, but I should double-check to make sure I'm looking at the correct columns. Let me confirm:\n\nThe columns are labeled as \"less than 1 hour,\" then \"1 to 5 hours,\" then \"more than 5 hours.\" The headers for the columns are a bit confusing because there are multiple columns with \"from\" and \"to\" for the confidence intervals, but the percentages for each time category are in the first column of each set. For example, under \"less than 1 hour,\" the first number is the percentage, then the next two are the confidence interval. Then for \"1 to 5 hours,\" same structure, and then \"more than 5 hours.\"\n\nSo yes, the numbers under \"more than 5 hours\" for each age group are 19.8, 16.9, and 15.9. So 20-39 has the highest percentage. \n\nI don't see any other age groups listed here. The table only has three age groups. So the answer should be \"20 to 39.\"\n\nBut wait, the question says \"more than 5 hours a week.\" The table's \"more than 5 hours\" column is for that. So that's correct. \n\nI don't see any other information that would contradict this. The p-values and chi-square statistics might indicate if the differences are statistically significant, but the question isn't asking about significance, just which group has a higher prevalence. So even if the differences are significant, the answer is based on the percentages. \n\nSo the final answer is the age group 20 to 39.\n\n\nThe question asks which age group had a higher prevalence of engaging in utilitarian walking for more than 5 hours a week. By examining the \"more than 5 hours\" column for each age group in the table:\n\n- **20 to 39 years**: 19.8% \n- **40 to 59 years**: 16.9% \n- **60 to 79 years**: 15.9% \n\nThe highest percentage is found in the **20 to 39 age group**.\n\n$$\n\\boxed{20\\ \\text{to}\\ 39}\n$$", - ground_truth = "20 to 39" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\nOkay, let's see. The user is asking what the total liabilities will be in 2015 if they continue to grow at the current rate. Hmm. So I need to figure out the growth rate of total liabilities from the given data and then project that into 2015.\n\nFirst, I should check the tables provided. Let me look through the balance sheet data. There's a section with \"Total liabilities\" for different years. Let me find that. \n\nLooking at the table under \"BALANCE SHEET DATA:\", there's a row for Total liabilities. The values are:\n\n- 2014: $113,589 million\n- 2013: $102,958 million\n- 2012: $102,924 million\n- 2011: $106,261 million\n- 2010: $106,995 million\n\nWait, but the question is about 2015. So the user wants to know the projected value for 2015 based on the current growth rate. But what's the current growth rate? The data available goes up to 2014. So I need to calculate the growth rate from previous years and apply it to 2014 to get 2015.\n\nBut first, let me check the growth rates between the years. Let's calculate the year-over-year growth for Total liabilities.\n\nFrom 2010 to 2011: 106,995 to 106,261. Wait, that's a decrease. Let me compute the percentage change. \n\nThe formula for percentage change is ((New Value - Old Value)/Old Value)*100.\n\nFor 2010 to 2011: (106,261 - 106,995)/106,995 * 100 = (-734)/106,995 \u2248 -0.686%. So a decrease.\n\nFrom 2011 to 2012: 102,924 - 106,261 = -3,337. (-3,337 / 106,261) * 100 \u2248 -3.14%.\n\nFrom 2012 to 2013: 102,958 - 102,924 = 34. (34 / 102,924) * 100 \u2248 0.033%.\n\nFrom 2013 to 2014: 113,589 - 102,958 = 10,631. (10,631 / 102,958) * 100 \u2248 10.32%.\n\nSo the growth rates are volatile. From 2010 to 2014, there's a big increase in 2013-2014. But the question is about the \"current rate\". But what's the current rate? If the user is asking as of 2014, then the latest growth rate would be from 2013 to 2014, which is 10.32%. However, maybe they want to calculate the average growth rate over the years, or maybe the most recent trend.\n\nBut the problem is that the data only goes up to 2014. So if we are to assume that the growth rate from 2013 to 2014 continues into 2015, then we can apply that 10.32% growth rate to 2014's value.\n\nAlternatively, maybe the user expects using the average growth rate over the years. But given that the question says \"continues to grow at its current rate\", it's more likely referring to the most recent rate, which is the 2013-2014 growth rate of ~10.32%.\n\nSo let's calculate that. \n\nTotal liabilities in 2014: $113,589 million.\n\nIf it grows by 10.32% in 2015, then:\n\n2015 liabilities = 113,589 * (1 + 0.1032) = 113,589 * 1.1032 \u2248 ?\n\nLet me compute that. \n\nFirst, 113,589 * 1.1 = 124,947.9\n\nThen, 113,589 * 0.0032 = 363.4848\n\nAdding them together: 124,947.9 + 363.4848 \u2248 125,311.38 million.\n\nBut wait, the exact growth rate from 2013 to 2014 was 10,631 / 102,958 = let me calculate that more accurately.\n\n10,631 divided by 102,958 equals approximately 0.10323, which is 10.323%. So 10.323%.\n\nSo 113,589 * 1.10323 = ?\n\nAlternatively, maybe I should calculate it as:\n\n113,589 + (113,589 * 0.10323) = ?\n\nBut perhaps a better way is to compute 113,589 * 1.10323.\n\nAlternatively, since the exact value of the change from 2013 to 2014 was $10,631 million. So the growth rate is 10,631 / 102,958 \u2248 0.10323.\n\nSo if we apply the same growth rate to 2014's total liabilities, then the 2015 value would be 113,589 * (1 + 0.10323) \u2248 113,589 * 1.10323.\n\nLet me compute this:\n\nFirst, multiply 113,589 by 1.10323.\n\nBreaking it down:\n\n113,589 * 1 = 113,589\n\n113,589 * 0.10 = 11,358.9\n\n113,589 * 0.00323 \u2248 113,589 * 0.003 = 340.767, and 113,589 * 0.00023 \u2248 26.125. So total \u2248 340.767 + 26.125 \u2248 366.892.\n\nSo adding up: 113,589 + 11,358.9 + 366.892 \u2248 113,589 + 11,725.792 \u2248 125,314.79 million.\n\nSo approximately $125,315 million. But maybe the answer expects using the exact percentage change.\n\nAlternatively, maybe the problem expects using the average annual growth rate over the years. However, given that the question says \"current rate\", which likely refers to the most recent growth rate, which is 2013-2014.\n\nBut another thought: perhaps the user is considering the growth from 2014 to 2015 as the same as the previous year's growth. However, the problem is that the data only gives up to 2014. So if we take the latest growth rate (from 2013 to 2014), then apply it to 2014 to get 2015.\n\nAlternatively, maybe the question is expecting a different approach. Let me check the data again.\n\nLooking at the Total liabilities for 2014 is 113,589. The previous year, 2013, was 102,958. The change is +10,631. So the growth rate is 10,631 / 102,958 = approx 10.32%.\n\nSo if this rate continues, then 2015 would be 113,589 * 1.1032 \u2248 125,315 million.\n\nBut I need to check if there's another way. For example, maybe the question is referring to the growth rate between 2014 and 2015 as per the previous years' trends. However, without more data points, it's hard to calculate an average. But since only the 2013-2014 growth is available, that's the only rate we can use.\n\nAlternatively, maybe the question is expecting to use the percent change from the \"As of December 31\" table. Wait, looking at the tables, there's another table with \"Total liabilities\" for 2014, 2013, etc. But that's the same as what I already considered.\n\nAnother possibility: maybe the question is referring to the \"Change\" column in the first table. Wait, the first table has a \"Change\" column for the commitment amounts. But that's for commitment amounts, not liabilities. So probably not relevant.\n\nSo, based on the available data, the answer would be Total liabilities in 2015 = 113,589 * (1 + 10.32%) \u2248 125,315 million.\n\nBut I should check if there's a different interpretation. For example, maybe the question is referring to the \"Percent\" column in the first table. But that's for the commitment amounts, not liabilities. So no.\n\nAlternatively, maybe the question is looking for a different approach. For instance, if the user considers the growth from 2010 to 2014 and extrapolates. Let me check that.\n\nFrom 2010: 106,995\n\n2014: 113,589\n\nThe total growth over 4 years (2010-2014). But that's 4 years. However, the question is about 2015, so one more year. But without knowing the exact growth rate, it's hard. But maybe calculate the CAGR (Compound Annual Growth Rate).\n\nCAGR = (Ending Value / Beginning Value)^(1/n) - 1\n\nWhere n is the number of years. From 2010 to 2014 is 4 years.\n\nSo CAGR = (113,589 / 106,995)^(1/4) - 1.\n\nCalculate 113,589 / 106,995 \u2248 1.0616.\n\nThen, 1.0616^(1/4) \u2248 e^(ln(1.0616)/4) \u2248 ln(1.0616) \u2248 0.0598. Divided by 4 is ~0.01495. So e^0.01495 \u2248 1.0151, so approximately 1.51% CAGR.\n\nIf we apply that to 2014's value, 113,589 * 1.0151 \u2248 115,310 million. But this is different from the previous 125,315. However, the question says \"current rate\", which might imply the most recent rate rather than the average over the past 4 years.\n\nBut since the user hasn't specified, but the question is \"continues to grow at its current rate\", which would refer to the latest growth rate, which is 10.32% as calculated from 2013-2014. So I think that's the right approach.\n\nTherefore, the answer would be approximately $125,315 million. But I need to check if there's any other data that I missed.\n\nWait, looking back at the tables, there's another table with \"Other Balance Sheet Data\" that includes Capital ratios, but nothing about liabilities growth. So no, the only data is what I already considered.\n\nSo, the answer is Total liabilities in 2015 = 113,589 * 1.1032 \u2248 125,315 million. But to be precise, let me calculate it exactly.\n\nExact calculation:\n\n10,631 / 102,958 = 0.10323 (approx 10.323%)\n\nSo 113,589 * 1.10323 = ?\n\nLet me compute 113,589 * 1.10323:\n\nFirst, 113,589 * 1 = 113,589\n\n113,589 * 0.1 = 11,358.9\n\n113,589 * 0.00323 = ?\n\nLet me compute 113,589 * 0.003 = 340.767\n\n113,589 * 0.00023 = 26.12547\n\nTotal 340.767 + 26.12547 = 366.89247\n\nSo adding up:\n\n113,589 + 11,358.9 + 366.89247 = 113,589 + 11,725.79247 = 125,314.79247 million.\n\nSo approximately $125,314.79 million, which rounds to $125,315 million.\n\nBut the problem says \"in millions\", so we can present it as 125,315 million. However, maybe the answer expects a different method, like using the percentage change from the previous year's total liabilities. But that's exactly what I did.\n\nAlternatively, maybe there's a different way. For example, looking at the \"Change\" column in the first table. But that's for commitment amounts, not liabilities. So no.\n\nTherefore, I think the answer is approximately $125,315 million.\n\n\nTo determine what the **Total liabilities** will be in **2015** if it continues to grow at its **current rate**, we must analyze the **growth rate of Total liabilities** between the most recent years provided in the data.\n\n---\n\n### Step 1: Identify the Relevant Data\n\nFrom the **Balance Sheet Data** table, we find the **Total liabilities** for the years 2013 and 2014:\n\n| Year | Total Liabilities (in millions) |\n|------|----------------------------------|\n| 2013 | $102,958 |\n| 2014 | $113,589 |\n\n---\n\n### Step 2: Calculate the Growth Rate from 2013 to 2014\n\nThe growth rate is calculated as:\n\n$$\n\\text{Growth Rate} = \\frac{\\text{2014 Liabilities} - \\text{2013 Liabilities}}{\\text{2013 Liabilities}} \\times 100\n$$\n\n$$\n\\text{Growth Rate} = \\frac{113,589 - 102,958}{102,958} \\times 100 = \\frac{10,631}{102,958} \\times 100 \\approx 10.323\\%\n$$\n\n---\n\n### Step 3: Apply the Growth Rate to Project 2015 Liabilities\n\nTo estimate the **Total liabilities in 2015**, we apply the 10.323% growth rate to the 2014 value:\n\n$$\n\\text{2015 Liabilities} = 113,589 \\times (1 + 0.10323)\n$$\n\n$$\n\\text{2015 Liabilities} = 113,589 \\times 1.10323 \\approx 125,314.79\n$$\n\n---\n\n### Final Answer\n\n$$\n\\boxed{125315}\n$$" - ground_truth = "125317.71131" - score = compute_score(output, ground_truth)["score"] - assert score == 1.0 - - output = "\nOkay, let me try to figure out the answer to this question. The user is asking, \"what is the incarceration rate of total correctional services?\" \n\nFirst, I need to look at the table provided and understand what each column represents. The table has several columns, including \"jurisdiction,\" \"custody,\" \"community supervision,\" \"total correctional services,\" and others. The user is specifically asking about the \"incarceration rate of total correctional services.\" \n\nLooking at the columns, there's a \"total correctional services\" section with \"number\" and \"rate\" columns. The question is about the rate, so I need to find the \"rate\" under the \"total correctional services\" column. \n\nBut wait, the question is about the incarceration rate. However, the table has separate columns for \"custody\" and \"community supervision,\" and then \"total correctional services.\" The \"custody\" column likely refers to the incarceration rate, as custody would mean being in prison or jail. The \"total correctional services\" would include both custody and community supervision. But the user is asking specifically for the incarceration rate, which might be under the \"custody\" section. However, the question says \"total correctional services,\" which might be confusing. \n\nWait, let me check the exact wording again. The question is \"incarceration rate of total correctional services.\" But \"total correctional services\" might include both incarceration (custody) and community supervision. However, the term \"incarceration rate\" typically refers to the rate of people in custody (prison). But maybe in this table's context, the \"total correctional services\" rate is the combined rate. But the user might be asking for the incarceration rate within the total correctional services. \n\nAlternatively, maybe the user is confused and actually wants the incarceration rate from the custody column. Let me check the table structure again. \n\nLooking at the table headers:\n\n- The first row has \"jurisdiction,\" \"custody,\" then \"percent change incarceration rate from 2012/2013 to 2013/2014,\" then \"community supervision,\" \"percent change in community supervision rate...\", then \"total correctional services,\" \"percent change in total rate...\", and finally \"percent change in total rate from 2009/2010 to 2013/2014.\"\n\nUnder \"custody,\" there are columns \"number\" and \"rate.\" Similarly for \"community supervision\" and \"total correctional services.\" So the \"rate\" under \"custody\" is the incarceration rate (since custody is where people are incarcerated). The \"total correctional services\" rate would be the combined rate of both custody and community supervision. But the question is about \"incarceration rate of total correctional services,\" which is a bit confusing. \n\nAlternatively, maybe the user is asking for the incarceration rate (from custody) but within the total correctional services. But that might not make sense. Alternatively, perhaps the user is asking for the \"total correctional services\" rate, but that's not incarceration. \n\nWait, maybe the question is phrased incorrectly. If the user is asking for the incarceration rate, that would be under the \"custody\" column's \"rate.\" However, the question specifically says \"total correctional services,\" which might be a mistake. But I have to go with the question as given. \n\nLooking at the table, for each jurisdiction, there's a \"total correctional services\" rate. For example, Newfoundland and Labrador has a total correctional services rate of 467. But what does that rate represent? It's possible that it's the combined rate of custody and community supervision. However, the term \"incarceration rate\" specifically refers to the custody rate. \n\nBut the user's question is about \"incarceration rate of total correctional services.\" Maybe they mean the incarceration rate as part of total correctional services. However, without more context, it's unclear. \n\nAlternatively, perhaps the user is referring to the \"total correctional services\" rate as the incarceration rate. But that's not accurate. \n\nAlternatively, maybe the question is asking for the incarceration rate (custody rate) for the total correctional services. But that's not clear. \n\nLooking at the data, for example, in the \"total correctional services\" row for provinces and territories - total, the rate is 492. But that's the combined rate. However, the user might be asking for the incarceration rate (custody) for the total correctional services. But the table doesn't have a column that combines them. \n\nAlternatively, maybe the user is asking for the \"total correctional services\" rate, which might be the sum of custody and community supervision. But the question says \"incarceration rate,\" which is different. \n\nThis is a bit confusing. However, given that the user specifically mentions \"incarceration rate of total correctional services,\" and the table has a \"total correctional services\" column with a \"rate\" column, maybe they are referring to that rate. However, that rate would include both custody and community supervision. But the term \"incarceration rate\" is specific to custody. \n\nAlternatively, perhaps there's a misunderstanding in the question. If the user wants the incarceration rate (custody rate) for the total correctional services, but the table doesn't have that. However, the \"total correctional services\" rate might be a composite. \n\nAlternatively, maybe the user is asking for the \"custody\" rate, and the mention of \"total correctional services\" is a mistake. \n\nBut given the information, the safest approach is to check the \"total correctional services\" rate. For example, the provinces and territories - total has a rate of 492. However, the question doesn't specify a particular jurisdiction. The question is general. \n\nWait, the question is \"what is the incarceration rate of total correctional services?\" Without specifying a jurisdiction. Looking at the table, there's a row for \"provinces and territories - total\" and \"federal.\" \n\nBut the user might be asking for the rate in general. However, the table has different rates for different jurisdictions. But since the question doesn't specify a jurisdiction, perhaps the answer is the rate for the total of all provinces and territories. \n\nBut the user might be looking for the \"custody\" rate. For example, in the \"custody\" column, the \"rate\" for provinces and territories - total is 87. But the question is about \"total correctional services.\" \n\nAlternatively, maybe the user is confused and wants the \"total correctional services\" rate. However, the term \"incarceration rate\" is different. \n\nThis is a bit ambiguous. But given the information, I'll proceed by checking the \"total correctional services\" rate. For example, in the \"provinces and territories - total\" row, the \"total correctional services\" rate is 492. However, the \"custody\" rate for the same row is 87. \n\nBut the user is asking about incarceration rate, which is custody. However, the question says \"total correctional services.\" \n\nAlternatively, maybe the user is referring to the \"total correctional services\" as the sum of custody and community supervision, and the incarceration rate is part of that. But without more data, it's hard to say. \n\nAlternatively, perhaps the question is asking for the incarceration rate (custody) in the context of total correctional services. For example, the incarceration rate as a part of total correctional services. But the table doesn't provide that. \n\nGiven the ambiguity, but considering that the user might have made a mistake in the question, and given that \"total correctional services\" includes custody, perhaps the answer is the \"custody\" rate. However, the question specifically says \"total correctional services.\" \n\nAlternatively, maybe the user is referring to the \"total correctional services\" rate as the incarceration rate. But that's not accurate. \n\nIn the absence of more information, I'll have to make an assumption. If the question is about the incarceration rate, it's likely referring to the \"custody\" rate. However, the user's wording is confusing. \n\nLooking back at the table, the \"custody\" column has \"rate\" and \"number,\" and the \"total correctional services\" also has \"rate\" and \"number.\" The question is about \"incarceration rate of total correctional services.\" If \"total correctional services\" includes custody (which is incarceration) and community supervision, then the incarceration rate would be part of that. But the question is asking for the rate of total correctional services. \n\nAlternatively, maybe the user is asking for the total correctional services rate, which is the combined rate. For example, the \"total correctional services\" rate for provinces and territories - total is 492. \n\nBut the question is phrased as \"incarceration rate of total correctional services,\" which is a bit conflicting. However, if we take it literally, maybe the answer is the \"total correctional services\" rate. But I'm not sure. \n\nAlternatively, maybe the user meant the incarceration rate (custody) for the total correctional services, which would be the custody rate. \n\nSince the user might have intended to ask for the custody rate, but mentioned \"total correctional services,\" I'll check both. \n\nFor example, in the \"provinces and territories - total\" row, the custody rate is 87, and the total correctional services rate is 492. \n\nBut without more context, it's hard to know. However, given the way the table is structured, the \"incarceration rate\" would be under the \"custody\" column. So maybe the user made a mistake in the question and meant \"custody\" instead of \"total correctional services.\" \n\nAlternatively, maybe the answer is the \"total correctional services\" rate. \n\nBut since the user is asking for \"incarceration rate,\" which is typically custody, I think the answer is the custody rate. However, the question says \"total correctional services.\" \n\nThis is really ambiguous. But I need to make a decision. \n\nPerhaps the answer is the \"total correctional services\" rate, but that's not incarceration. Alternatively, maybe the user is asking for the total correctional services rate, and the term \"incarceration rate\" is a mistake. \n\nGiven that the user might have intended to ask for the \"total correctional services\" rate, I'll provide that. But I need to check the table. \n\nLooking at the \"total correctional services\" rate for provinces and territories - total, it's 492. For federal, it's 81. But the question doesn't specify. \n\nAlternatively, if the user is asking for the incarceration rate (custody) for total correctional services, perhaps they mean the custody rate. \n\nSince the user is asking for \"incarceration rate,\" which is custody, and the question says \"of total correctional services,\" which might mean the rate within the total correctional services. But the table doesn't have that. \n\nAlternatively, maybe the answer is the custody rate for the total correctional services. But the table doesn't have that. \n\nI think the safest answer is that the incarceration rate is under the \"custody\" column. For example, in the \"provinces and territories - total\" row, the custody rate is 87. But the question is not specific to a jurisdiction. \n\nBut the question is general. If the user is asking for the incarceration rate of total correctional services, and \"total correctional services\" is the sum of custody and community supervision, but the question is about incarceration rate, which is custody. \n\nGiven that, and since the user might have made a mistake, I think the answer is the custody rate. However, without more context, it's hard to tell. \n\nBut the user might have intended to ask for the total correctional services rate. In that case, the answer would be 492 for provinces and territories - total. \n\nBut the question is not specific to a jurisdiction. So the answer could be multiple values. \n\nAlternatively, the user might be asking for the rate in the \"total correctional services\" column. \n\nSince the user hasn't specified a jurisdiction, but the table has a \"provinces and territories - total\" row, I'll go with that. \n\nSo the answer would be 492. \n\nBut I'm not entirely sure. However, given the information, this is the best I can do.\n\n\nThe question asks for the \"incarceration rate of total correctional services.\" Based on the table structure:\n\n- **\"Custody\"** likely represents the incarceration rate (people in prison/jail), with a \"rate\" column.\n- **\"Total correctional services\"** includes both custody and community supervision.\n\nHowever, the term \"incarceration rate\" specifically refers to custody. Since the question mentions \"total correctional services,\" there is ambiguity. If the intent was to ask for the **custody rate** (incarceration rate), the answer for \"provinces and territories - total\" is **87**. If the question refers to the **total correctional services rate**, the value is **492**. \n\nGiven the phrasing, the most precise answer based on the table's structure is the **custody rate** for \"provinces and territories - total\" as the incarceration rate. However, if the question intended the \"total correctional services\" rate, the answer is **492**.\n\n**Final Answer:** \n\\boxed{87} | \\boxed{492}", - ground_truth = "492|87" diff --git a/tests/single_controller/__init__.py b/tests/single_controller/__init__.py new file mode 100644 index 000000000..1cd1e8433 --- /dev/null +++ b/tests/single_controller/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/single_controller/base/test_decorator.py b/tests/single_controller/base/test_decorator.py index 256244916..5447d65ce 100644 --- a/tests/single_controller/base/test_decorator.py +++ b/tests/single_controller/base/test_decorator.py @@ -15,7 +15,14 @@ import pytest import verl.single_controller.base.decorator as decorator_module -from verl.single_controller.base.decorator import DISPATCH_MODE_FN_REGISTRY, Dispatch, _check_dispatch_mode, get_predefined_dispatch_fn, register_dispatch_mode, update_dispatch_mode +from verl.single_controller.base.decorator import ( + DISPATCH_MODE_FN_REGISTRY, + Dispatch, + _check_dispatch_mode, + get_predefined_dispatch_fn, + register_dispatch_mode, + update_dispatch_mode, +) @pytest.fixture @@ -42,7 +49,10 @@ def dummy_collect(worker_group, output): _check_dispatch_mode(Dispatch.TEST_MODE) # Verify registry update - assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {"dispatch_fn": dummy_dispatch, "collect_fn": dummy_collect} + assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == { + "dispatch_fn": dummy_dispatch, + "collect_fn": dummy_collect, + } # Clean up Dispatch.remove("TEST_MODE") diff --git a/tests/ray_cpu/check_worker_alive/main.py b/tests/single_controller/check_worker_alive/main.py similarity index 100% rename from tests/ray_cpu/check_worker_alive/main.py rename to tests/single_controller/check_worker_alive/main.py diff --git a/tests/ray_gpu/detached_worker/README.md b/tests/single_controller/detached_worker/README.md similarity index 100% rename from tests/ray_gpu/detached_worker/README.md rename to tests/single_controller/detached_worker/README.md diff --git a/tests/ray_gpu/detached_worker/client.py b/tests/single_controller/detached_worker/client.py similarity index 93% rename from tests/ray_gpu/detached_worker/client.py rename to tests/single_controller/detached_worker/client.py index 262bc66ad..52f2c7242 100644 --- a/tests/ray_gpu/detached_worker/client.py +++ b/tests/single_controller/detached_worker/client.py @@ -34,7 +34,9 @@ def compute_position_id_with_mask(mask): # get the worker group using names worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"] cls_with_init_args = RayClassWithInitArgs(cls=Trainer) - worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args) + worker_group = NVMegatronRayWorkerGroup.from_detached( + worker_names=worker_names, ray_cls_with_init=cls_with_init_args + ) batch_size = 16 sequence_length = 1024 diff --git a/tests/ray_gpu/detached_worker/run.sh b/tests/single_controller/detached_worker/run.sh similarity index 100% rename from tests/ray_gpu/detached_worker/run.sh rename to tests/single_controller/detached_worker/run.sh diff --git a/tests/ray_gpu/detached_worker/server.py b/tests/single_controller/detached_worker/server.py similarity index 95% rename from tests/ray_gpu/detached_worker/server.py rename to tests/single_controller/detached_worker/server.py index 084122870..57e555a3a 100644 --- a/tests/ray_gpu/detached_worker/server.py +++ b/tests/single_controller/detached_worker/server.py @@ -114,12 +114,16 @@ def train_model(self, data: DataProto) -> DataProto: position_ids = data.batch["position_ids"] self.optimizer.zero_grad() - self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm + self.model.zero_grad_buffer( + zero_buffer=(not self.optimizer_config.use_distributed_optimizer) + ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm # update for 1 iteration output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits output.mean().backward() - update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config, self.megatron_config.timers) + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( + self.megatron_config, self.megatron_config.timers + ) return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) diff --git a/tests/ray_cpu/test_auto_padding.py b/tests/single_controller/test_auto_padding_on_cpu.py similarity index 84% rename from tests/ray_cpu/test_auto_padding.py rename to tests/single_controller/test_auto_padding_on_cpu.py index a961cd9d5..f2c441243 100644 --- a/tests/ray_cpu/test_auto_padding.py +++ b/tests/single_controller/test_auto_padding_on_cpu.py @@ -52,11 +52,16 @@ def test_auto_padding(): padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0 local_data.padding(padding_size) # print(f"after padding, local_data = {local_data}") - assert len(local_data) == len(local_data) + len(local_data) % chunk_size, f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}" + assert len(local_data) == len(local_data) + len(local_data) % chunk_size, ( + f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}" + ) chunked = local_data.chunk(chunk_size) assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}" for dp in chunked: - assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), f"test size = {test_size}, expecting dp to be length of {test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}" + assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), ( + f"test size = {test_size}, expecting dp to be length of " + f"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}" + ) # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}) @@ -80,12 +85,16 @@ def test_auto_padding(): # test data proto specific config DataProtoConfig.auto_padding = False - data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True) + data = DataProto.from_dict( + {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True + ) output = actor_wg.add(data) print(output.batch["a"]) assert len(output) == 10 - data = DataProto.from_single_dict({"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True) + data = DataProto.from_single_dict( + {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True + ) output = actor_wg.add(data) print(output.batch["a"]) diff --git a/tests/ray_gpu/test_colocated_workers.py b/tests/single_controller/test_colocated_workers.py similarity index 100% rename from tests/ray_gpu/test_colocated_workers.py rename to tests/single_controller/test_colocated_workers.py diff --git a/tests/ray_gpu/test_colocated_workers_fused.py b/tests/single_controller/test_colocated_workers_fused.py similarity index 100% rename from tests/ray_gpu/test_colocated_workers_fused.py rename to tests/single_controller/test_colocated_workers_fused.py diff --git a/tests/ray_gpu/test_data_transfer.py b/tests/single_controller/test_data_transfer.py similarity index 97% rename from tests/ray_gpu/test_data_transfer.py rename to tests/single_controller/test_data_transfer.py index fdd854e32..13777b0bd 100644 --- a/tests/ray_gpu/test_data_transfer.py +++ b/tests/single_controller/test_data_transfer.py @@ -96,7 +96,7 @@ def test_data_transfer(): # takes around 40 seconds output_lst = ray.get(output_ref) - for input_data, output_data in zip(data_list, output_lst): + for input_data, output_data in zip(data_list, output_lst, strict=True): for key in input_data.batch.keys(): assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( input_data.batch[key], diff --git a/tests/ray_cpu/test_decorator.py b/tests/single_controller/test_decorator_on_cpu.py similarity index 93% rename from tests/ray_cpu/test_decorator.py rename to tests/single_controller/test_decorator_on_cpu.py index 7cd0c77bb..4dfec6331 100644 --- a/tests/ray_cpu/test_decorator.py +++ b/tests/single_controller/test_decorator_on_cpu.py @@ -70,7 +70,9 @@ def test_decorator_dp_compute(ray_init_shutdown): num_workers = 2 resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) # Use CPU for simplicity cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10) - worker_group = RayWorkerGroup(resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}") + worker_group = RayWorkerGroup( + resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}" + ) # Prepare input data (size 4, for 2 workers) input_tensor = torch.arange(4, dtype=torch.float32) @@ -104,7 +106,9 @@ def test_decorator_async_function(ray_init_shutdown): num_workers = 2 resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5) - worker_group = RayWorkerGroup(resource_pool, cls_with_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}") + worker_group = RayWorkerGroup( + resource_pool, cls_with_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}" + ) # Prepare input data (size 4, for 2 workers) input_tensor = torch.arange(4, dtype=torch.float32) @@ -132,4 +136,6 @@ def test_decorator_async_function(ray_init_shutdown): expected_output_part2 = (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1 expected_output = torch.cat([expected_output_part1, expected_output_part2]) - torch.testing.assert_close(result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch") + torch.testing.assert_close( + result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch" + ) diff --git a/tests/ray_gpu/test_driverfunc_to_worker.py b/tests/single_controller/test_driverfunc_to_worker.py similarity index 92% rename from tests/ray_gpu/test_driverfunc_to_worker.py rename to tests/single_controller/test_driverfunc_to_worker.py index b20cc5f7c..a38d790d6 100644 --- a/tests/ray_gpu/test_driverfunc_to_worker.py +++ b/tests/single_controller/test_driverfunc_to_worker.py @@ -43,7 +43,11 @@ def get_aux_metrics(self, test_proto): decode_count = [] for i in range(sequence_ids.size(0)): decode_count.append(len(sequence_ids[i].tolist())) - ret_proto = DataProto(batch=TensorDict({"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0))) + ret_proto = DataProto( + batch=TensorDict( + {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0) + ) + ) return ret_proto diff --git a/tests/ray_cpu/test_fused_workers.py b/tests/single_controller/test_fused_workers_on_cpu.py similarity index 92% rename from tests/ray_cpu/test_fused_workers.py rename to tests/single_controller/test_fused_workers_on_cpu.py index 89f575484..527ddc102 100644 --- a/tests/ray_cpu/test_fused_workers.py +++ b/tests/single_controller/test_fused_workers_on_cpu.py @@ -16,7 +16,12 @@ from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_raw_cls +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + create_colocated_worker_raw_cls, +) @ray.remote @@ -75,7 +80,7 @@ def test_fused_workers(): print(y) z = fused_wg.foo(0.1) print(z) - for i, j in zip(y, z): + for i, j in zip(y, z, strict=True): assert i == j ray.shutdown() diff --git a/tests/ray_gpu/test_high_level_scheduling_api.py b/tests/single_controller/test_high_level_scheduling_api.py similarity index 100% rename from tests/ray_gpu/test_high_level_scheduling_api.py rename to tests/single_controller/test_high_level_scheduling_api.py diff --git a/tests/single_controller/test_ray_collectives.py b/tests/single_controller/test_ray_collectives.py new file mode 100644 index 000000000..3722a8f80 --- /dev/null +++ b/tests/single_controller/test_ray_collectives.py @@ -0,0 +1,113 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test for using ray collective group. +Suppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to +Rollout relationship by using collective groups +Actor: rank 0, 1 - Rollout rank 0 +Rollout rank 2, 3 - Rollout rank 1 +Then, we initiate 4 p2p comms from actor to rollout +""" + +import ray +import ray.util.collective as collective +import torch + +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + + +@ray.remote +class Actor(Worker): + @register(Dispatch.ONE_TO_ALL) + def init(self): + remote_rank = self.rank // 2 + self.group_name = f"A{self.rank}_R{remote_rank}" + collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name) + + @register(Dispatch.ONE_TO_ALL, blocking=False) + def send_tensors(self): + tensor = torch.ones(size=(4,), dtype=torch.float32, device="cuda") * self.rank + collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name) + + +@ray.remote +class Rollout(Worker): + @register(Dispatch.ONE_TO_ALL) + def init(self): + self.remote_first_rank = self.rank * 2 + self.remote_second_rank = self.remote_first_rank + 1 + self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}" + self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}" + + collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name) + collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name) + + @register(Dispatch.ONE_TO_ALL, blocking=False) + def receive_tensors(self): + self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device="cuda") + self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device="cuda") + + collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name) + collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name) + + @register(Dispatch.ONE_TO_ALL) + def get_tensors(self): + return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2} + + +def test_ray_collective_group(): + ray.init() + + actor_resource_pool = RayResourcePool([4]) + rollout_resource_pool = RayResourcePool([2]) + + actor_cls = RayClassWithInitArgs(cls=Actor) + rollout_cls = RayClassWithInitArgs(cls=Rollout) + + actor_wg = RayWorkerGroup( + resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor" + ) + rollout_wg = RayWorkerGroup( + resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout" + ) + + actor_wg.init() + rollout_wg.init() + + out1 = actor_wg.send_tensors() + out2 = rollout_wg.receive_tensors() + + # block to wait + ray.get(out1) + ray.get(out2) + + output = rollout_wg.get_tensors() + + rollout_0_output = output[0] + rollout_1_output = output[1] + + output = rollout_0_output | rollout_1_output + + print(output) + + for i in range(4): + assert torch.sum(output[f"src_{i}"]).item() == 4 * i + + ray.shutdown() + + +if __name__ == "__main__": + test_ray_collective_group() diff --git a/tests/ray_cpu/test_ray_local_envs.py b/tests/single_controller/test_ray_local_envs_on_cpu.py similarity index 91% rename from tests/ray_cpu/test_ray_local_envs.py rename to tests/single_controller/test_ray_local_envs_on_cpu.py index a59193bd6..ee6c0cbed 100644 --- a/tests/ray_cpu/test_ray_local_envs.py +++ b/tests/single_controller/test_ray_local_envs_on_cpu.py @@ -40,7 +40,9 @@ def test_basics(): resource_pool = RayResourcePool([4], use_gpu=False) class_with_args = RayClassWithInitArgs(cls=TestActor) - worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic") + worker_group = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + ) output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE") assert output == ["4", "4", "4", "4"] diff --git a/tests/ray_cpu/test_ray_utils.py b/tests/single_controller/test_ray_utils_on_cpu.py similarity index 100% rename from tests/ray_cpu/test_ray_utils.py rename to tests/single_controller/test_ray_utils_on_cpu.py diff --git a/tests/ray_gpu/test_rvdz.py b/tests/single_controller/test_rvdz.py similarity index 100% rename from tests/ray_gpu/test_rvdz.py rename to tests/single_controller/test_rvdz.py diff --git a/tests/ray_gpu/test_worker_group_basics.py b/tests/single_controller/test_worker_group_basics.py similarity index 92% rename from tests/ray_gpu/test_worker_group_basics.py rename to tests/single_controller/test_worker_group_basics.py index df684cb0a..5c4823dfb 100644 --- a/tests/ray_gpu/test_worker_group_basics.py +++ b/tests/single_controller/test_worker_group_basics.py @@ -68,7 +68,9 @@ def foo_custom(self, x, y): @ray.remote(num_gpus=0.1) def remote_call_wg(worker_names): class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None) + worker_group = RayWorkerGroup.from_detached( + worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None + ) print(worker_group.worker_names) output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6]) @@ -94,7 +96,9 @@ def test_basics(): resource_pool = RayResourcePool([4], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic") + worker_group = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + ) print(worker_group.worker_names) diff --git a/tests/ray_gpu/test_worker_group_torch.py b/tests/single_controller/test_worker_group_torch.py similarity index 91% rename from tests/ray_gpu/test_worker_group_torch.py rename to tests/single_controller/test_worker_group_torch.py index 3cf27b4dd..a601c43da 100644 --- a/tests/ray_gpu/test_worker_group_torch.py +++ b/tests/single_controller/test_worker_group_torch.py @@ -38,7 +38,9 @@ def init(self): def all_gather(self): world_size = self._world_size - output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device) + output = torch.zeros( + size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output @@ -55,7 +57,9 @@ def __init__(self, size) -> None: def all_gather(self): world_size = self._world_size - output = torch.zeros(size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device) + output = torch.zeros( + size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output diff --git a/tests/special_distributed/README.md b/tests/special_distributed/README.md new file mode 100644 index 000000000..f2f865e8b --- /dev/null +++ b/tests/special_distributed/README.md @@ -0,0 +1 @@ +This folder is reserved for unit tests (instead of end-to-end tests) that require multiple GPUs. diff --git a/tests/distributed/run_all.sh b/tests/special_distributed/run_all.sh similarity index 88% rename from tests/distributed/run_all.sh rename to tests/special_distributed/run_all.sh index f8654d7b5..c34edf222 100644 --- a/tests/distributed/run_all.sh +++ b/tests/special_distributed/run_all.sh @@ -15,4 +15,4 @@ #!/usr/bin/env bash set -e -x -torchrun --nproc-per-node=4 --standalone tests/distributed/test_tensor_dict.py \ No newline at end of file +torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py \ No newline at end of file diff --git a/tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py b/tests/special_distributed/test_fsdp_ckpt.py similarity index 88% rename from tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py rename to tests/special_distributed/test_fsdp_ckpt.py index f1ab930b7..49dceb7c1 100644 --- a/tests/utils/gpu_tests/checkpoint/test_fsdp_ckpt.py +++ b/tests/special_distributed/test_fsdp_ckpt.py @@ -36,12 +36,16 @@ def test_fsdp_ckpt(strategy="fsdp"): config = Qwen2Config(num_hidden_layers=1) with torch.device("cuda"): - model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) model = model.to(device="cuda") # Wrap model with FSDP if strategy == "fsdp": - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) model = FSDP( model, @@ -52,7 +56,9 @@ def test_fsdp_ckpt(strategy="fsdp"): device_mesh=device_mesh, ) else: - mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + ) fsdp_kwargs = { "mesh": device_mesh, "mp_policy": mp_policy, @@ -64,7 +70,9 @@ def test_fsdp_ckpt(strategy="fsdp"): # Create checkpoint manager tokenizer = AutoTokenizer.from_pretrained(model_name) - checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer) + checkpoint_manager = FSDPCheckpointManager( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer + ) # Generate sample input batch_size = 2 diff --git a/tests/distributed/test_tensor_dict.py b/tests/special_distributed/test_tensor_dict.py similarity index 80% rename from tests/distributed/test_tensor_dict.py rename to tests/special_distributed/test_tensor_dict.py index cb4a6f5e5..0a7f8039d 100644 --- a/tests/distributed/test_tensor_dict.py +++ b/tests/special_distributed/test_tensor_dict.py @@ -58,11 +58,13 @@ def test_all_gather_data_proto(): def test_vocab_parallel_entropy(): from megatron.core import parallel_state as mpu - from verl.utils.debug import log_gpu_memory_usage from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy + from verl.utils.profiler import log_gpu_memory_usage from verl.utils.torch_functional import entropy_from_logits - mpu.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None) + mpu.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None + ) batch_size = 2 seqlen = 128 @@ -72,14 +74,20 @@ def test_vocab_parallel_entropy(): target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64) # broadcast across tp - torch.distributed.broadcast(logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) - torch.distributed.broadcast(target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + torch.distributed.broadcast( + logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) + torch.distributed.broadcast( + target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) tp_rank = mpu.get_tensor_model_parallel_rank() vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() # get the local logits of each tp - vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() + vocab_parallel_logits = ( + logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() + ) logits.grad = None vocab_parallel_logits.grad = None @@ -93,9 +101,13 @@ def test_vocab_parallel_entropy(): target_entropy = entropy_from_logits(logits) torch.testing.assert_close(output_entropy, target_entropy) target_entropy.backward(grad_output) - torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad) + torch.testing.assert_close( + logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad + ) # make sure logits is not altered - torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits) + torch.testing.assert_close( + logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits + ) if mpu.get_tensor_model_parallel_rank() == 0: print("test_vocab_parallel_entropy passes") diff --git a/tests/special_e2e/README.md b/tests/special_e2e/README.md new file mode 100644 index 000000000..3c295e844 --- /dev/null +++ b/tests/special_e2e/README.md @@ -0,0 +1 @@ +This folder is reserved for end-to-end tests that typically require multiple GPUs. diff --git a/verl/third_party/vllm/vllm_v_0_5_4/__init__.py b/tests/special_e2e/__init__.py similarity index 100% rename from verl/third_party/vllm/vllm_v_0_5_4/__init__.py rename to tests/special_e2e/__init__.py diff --git a/tests/e2e/check_custom_rwd_fn.py b/tests/special_e2e/check_custom_rwd_fn.py similarity index 100% rename from tests/e2e/check_custom_rwd_fn.py rename to tests/special_e2e/check_custom_rwd_fn.py diff --git a/tests/e2e/check_results.py b/tests/special_e2e/check_results.py similarity index 100% rename from tests/e2e/check_results.py rename to tests/special_e2e/check_results.py diff --git a/tests/e2e/envs/__init__.py b/tests/special_e2e/envs/__init__.py similarity index 100% rename from tests/e2e/envs/__init__.py rename to tests/special_e2e/envs/__init__.py diff --git a/tests/e2e/envs/digit_completion/__init__.py b/tests/special_e2e/envs/digit_completion/__init__.py similarity index 100% rename from tests/e2e/envs/digit_completion/__init__.py rename to tests/special_e2e/envs/digit_completion/__init__.py diff --git a/tests/e2e/envs/digit_completion/task.py b/tests/special_e2e/envs/digit_completion/task.py similarity index 95% rename from tests/e2e/envs/digit_completion/task.py rename to tests/special_e2e/envs/digit_completion/task.py index f35caa473..c3643a86b 100644 --- a/tests/e2e/envs/digit_completion/task.py +++ b/tests/special_e2e/envs/digit_completion/task.py @@ -54,7 +54,11 @@ def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, see self.np_rng = np.random.default_rng(seed=seed) def __str__(self): - return f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, Max number: {self.max_number}. Max diff: {self.max_diff}, Max number in response: {self.max_num_in_response}" + return ( + f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, " + f"Max number: {self.max_number}. Max diff: {self.max_diff}, " + f"Max number in response: {self.max_num_in_response}" + ) def get_state(self): return {"rng": self.np_rng} diff --git a/tests/e2e/envs/digit_completion/tokenizer.py b/tests/special_e2e/envs/digit_completion/tokenizer.py similarity index 89% rename from tests/e2e/envs/digit_completion/tokenizer.py rename to tests/special_e2e/envs/digit_completion/tokenizer.py index f50a0ffb6..6ff471938 100644 --- a/tests/e2e/envs/digit_completion/tokenizer.py +++ b/tests/special_e2e/envs/digit_completion/tokenizer.py @@ -21,7 +21,7 @@ import json import os from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import Optional, Sequence from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer @@ -86,7 +86,7 @@ def vocab_size(self) -> int: def get_vocab(self): return self._vocab_str_to_int - def _tokenize(self, text: str) -> List[str]: + def _tokenize(self, text: str) -> list[str]: return list(text) def _convert_token_to_id(self, token: str) -> int: @@ -98,7 +98,9 @@ def _convert_id_to_token(self, index: int) -> str: def convert_tokens_to_string(self, tokens): return "".join(tokens) - def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: sep = [self.sep_token_id] cls = [self.cls_token_id] result = cls + token_ids_0 + sep @@ -108,10 +110,10 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: def get_special_tokens_mask( self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, + token_ids_0: list[int], + token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False, - ) -> List[int]: + ) -> list[int]: if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, @@ -124,7 +126,7 @@ def get_special_tokens_mask( result += ([0] * len(token_ids_1)) + [1] return result - def get_config(self) -> Dict: + def get_config(self) -> dict: return { "char_ords": [ord(ch) for ch in self.characters], "model_max_length": self.model_max_length, @@ -132,21 +134,21 @@ def get_config(self) -> Dict: } @classmethod - def from_config(cls, config: Dict): + def from_config(cls, config: dict): cfg = {} cfg["characters"] = [chr(i) for i in config["char_ords"]] cfg["model_max_length"] = config["model_max_length"] cfg["chat_template"] = config["chat_template"] return cls(**cfg) - def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + def save_pretrained(self, save_directory: str | os.PathLike, **kwargs): cfg_file = Path(save_directory) / "tokenizer_config.json" cfg = self.get_config() with open(cfg_file, "w") as f: json.dump(cfg, f, indent=4) @classmethod - def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): + def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs): cfg_file = Path(save_directory) / "tokenizer_config.json" with open(cfg_file) as f: cfg = json.load(f) diff --git a/tests/e2e/generation/run_gen_qwen05.sh b/tests/special_e2e/generation/run_gen_qwen05.sh similarity index 100% rename from tests/e2e/generation/run_gen_qwen05.sh rename to tests/special_e2e/generation/run_gen_qwen05.sh diff --git a/tests/e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json b/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json similarity index 100% rename from tests/e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json rename to tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/special_e2e/ppo_trainer/run_function_reward.sh similarity index 87% rename from tests/e2e/ppo_trainer/run_function_reward.sh rename to tests/special_e2e/ppo_trainer/run_function_reward.sh index dde760a9f..62bf410ef 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/special_e2e/ppo_trainer/run_function_reward.sh @@ -13,12 +13,20 @@ MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512} MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512} ENGINE=${ENGINE:-vllm} +ROLLOUT_MODE=${ROLLOUT_MODE:-sync} + +RETURN_RAW_CHAT="False" +if [ "$ROLLOUT_MODE" = "async" ]; then + RETURN_RAW_CHAT="True" +fi + GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8} ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} RM_PAD=${RM_PAD:-True} FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} USE_KL=${USE_KL:-False} CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False} @@ -27,6 +35,8 @@ STRATEGY=${STRATEGY:-fsdp} # LoRA config LORA_RANK=${LORA_RANK:-0} LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}} +LORA_TARGET=${LORA_TARGET:-"all-linear"} +LORA_EXCLUDE=${LORA_EXCLUDE:-"DONT_EXCLUDE"} USE_SHM=${USE_SHM:-False} LOAD_FORMAT=${LOAD_FORMAT:-dummy_dtensor} LAYERED_SUMMON=${LAYERED_SUMMON:-False} @@ -82,13 +92,17 @@ python3 -m verl.trainer.main_ppo \ data.train_batch_size="${train_prompt_bsz}" \ data.max_prompt_length="${MAX_PROMPT_LEN}" \ data.max_response_length="${MAX_RESPONSE_LEN}" \ + data.return_raw_chat=${RETURN_RAW_CHAT} \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.model.use_shm=${USE_SHM} \ actor_rollout_ref.model.lora_rank=${LORA_RANK} \ actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \ + actor_rollout_ref.model.target_modules=${LORA_TARGET} \ + actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.strategy=${STRATEGY} \ @@ -96,11 +110,12 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \ actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ - actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \ + actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \ actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name="${ENGINE}" \ + actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \ actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \ actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \ actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \ @@ -120,7 +135,7 @@ python3 -m verl.trainer.main_ppo \ algorithm.kl_penalty=kl \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl-test' \ trainer.experiment_name="${exp_name}" \ trainer.nnodes=1 \ @@ -130,11 +145,12 @@ python3 -m verl.trainer.main_ppo \ trainer.save_freq="${SAVE_FREQ}" \ trainer.resume_mode="${RESUME_MODE}" \ trainer.total_epochs=2 \ + trainer.device=cuda \ trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \ | tee "${output_file}" if [ "${CUSTOM_REWARD_FN}" = "True" ]; then - python3 tests/e2e/check_custom_rwd_fn.py --output_file="${output_file}" + python3 tests/special_e2e/check_custom_rwd_fn.py --output_file="${output_file}" check_exit_code=$? rm -rf "${reward_fn_file_path}" rm -rf "${output_file}" diff --git a/tests/e2e/ppo_trainer/run_model_reward.sh b/tests/special_e2e/ppo_trainer/run_model_reward.sh similarity index 93% rename from tests/e2e/ppo_trainer/run_model_reward.sh rename to tests/special_e2e/ppo_trainer/run_model_reward.sh index 4c11e7a27..e7711f96d 100644 --- a/tests/e2e/ppo_trainer/run_model_reward.sh +++ b/tests/special_e2e/ppo_trainer/run_model_reward.sh @@ -11,6 +11,8 @@ TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} RM_PAD=${RM_PAD:-True} +FUSED_KERNELS=${FUSED_KERNELS:-False} +FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend SP_SIZE=${SP_SIZE:-1} SEQ_BALANCE=${SEQ_BALANCE:-False} LIGER=${LIGER:-False} @@ -47,6 +49,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_liger="${LIGER}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ + actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ + actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \ @@ -84,7 +88,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl-test' \ trainer.experiment_name="${exp_name}" \ trainer.nnodes=1 \ diff --git a/tests/special_e2e/ppo_trainer/run_single_gpu.sh b/tests/special_e2e/ppo_trainer/run_single_gpu.sh new file mode 100644 index 000000000..7e8615a24 --- /dev/null +++ b/tests/special_e2e/ppo_trainer/run_single_gpu.sh @@ -0,0 +1,24 @@ +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.logger=console \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + actor_rollout_ref.rollout.name=hf \ + trainer.total_training_steps=2 \ No newline at end of file diff --git a/tests/e2e/run_dapo.sh b/tests/special_e2e/run_dapo.sh similarity index 99% rename from tests/e2e/run_dapo.sh rename to tests/special_e2e/run_dapo.sh index bdbc40b12..56ff0ae05 100644 --- a/tests/e2e/run_dapo.sh +++ b/tests/special_e2e/run_dapo.sh @@ -78,7 +78,7 @@ python3 -m recipe.dapo.main_dapo \ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl-test' \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=${NUM_GPUS} \ diff --git a/tests/special_e2e/run_genrm_remote.sh b/tests/special_e2e/run_genrm_remote.sh new file mode 100644 index 000000000..4819248be --- /dev/null +++ b/tests/special_e2e/run_genrm_remote.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +export no_proxy="localhost,127.0.0.1" + +set -x + +# Launch a vllm server +CUDA_VISIBLE_DEVICES=0 vllm serve verl-team/GenRM-CI-Test-1.5B \ + --served_model_name genrm-demo --host localhost --port 30000 > /dev/null & +SERVER_PID=$! + +# kill server when script exits +cleanup() { + echo "Cleaning up..." + kill $SERVER_PID 2>/dev/null || true + wait $SERVER_PID 2>/dev/null || true + echo "Cleanup done" +} +trap cleanup EXIT + +# wait for server to start +wait_for_server() { + local max_attempts=60 + local attempt=0 + local sleep_time=10 + + while [ $attempt -lt $max_attempts ]; do + if curl -s "http://localhost:30000/health" >/dev/null; then + echo "Server is up and running!" + return 0 + fi + echo "Waiting for server to start... (attempt $((attempt+1))/$max_attempts)" + sleep $sleep_time + ((attempt++)) + done + + echo "Error: Failed to start server after $max_attempts attempts" >&2 + return 1 +} + +if ! wait_for_server; then + exit 1 +fi + +CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=${HOME}/data/gsm8k/train.parquet \ + data.val_files=${HOME}/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=4 \ + algorithm.use_kl_in_reward=False \ + reward_model.reward_manager=batch \ + custom_reward_function.path=recipe/genrm_remote/reward_function.py \ + custom_reward_function.name=compute_score_batch \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl-test' \ + trainer.experiment_name='qwen2.5-0.5b-gen-rm' \ + trainer.n_gpus_per_node=4 \ + trainer.val_before_train=False \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.resume_mode='disable' \ + trainer.total_training_steps=1 diff --git a/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh b/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh new file mode 100644 index 000000000..caa9e664c --- /dev/null +++ b/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh @@ -0,0 +1,58 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +huggingface-cli download Qwen/Qwen2.5-VL-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-VL-3B-Instruct + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" +FSDP_STRATEGY=${FSDP_STRATEGY:-fsdp} + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=64 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name=qwen2.5-vl-3b_function_rm-geo3k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0619-verify-n8 \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + data.train_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/train.parquet \ + data.val_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + trainer.val_before_train=False \ + trainer.total_training_steps=1 $@ \ No newline at end of file diff --git a/tests/special_e2e/run_grpo_lora_with_merge.sh b/tests/special_e2e/run_grpo_lora_with_merge.sh new file mode 100644 index 000000000..192d935ba --- /dev/null +++ b/tests/special_e2e/run_grpo_lora_with_merge.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +# +# An e2e test script for testing the GRPO LoRA training process +# and processing the generated checkpoint using the merge_model.py script. + +set -xeuo pipefail + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +if [ ! -d "$MODEL_PATH" ]; then + echo "Downloading model to ${MODEL_PATH}..." + huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" +else + echo "Model directory ${MODEL_PATH} already exists, skip downloading." +fi + + +BATCH_SIZE=16 +EXP_NAME="qwen2.5_0.5b_grpo_lora" +# step 1. train model with grpo-lora for 1 step +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=${BATCH_SIZE} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.actor.optim.lr=3e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${BATCH_SIZE} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name=${EXP_NAME} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.total_training_steps=1 \ + trainer.save_freq=1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ + +# step 2. merge model +python3 -m verl.model_merger merge \ + --backend fsdp \ + --local_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/ \ + --target_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf + +# step 3. assert +# make sure adapter_model.safetensors exists and its size is larger than 1MB +file_path="checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf/lora_adapter/adapter_model.safetensors" + +if [ ! -f "$file_path" ]; then + echo "Error: File $file_path does not exist!" + exit 1 +fi + +file_size=$(stat -c %s "$file_path") + +min_size_mb=1 +min_size=$((min_size_mb * 1024 * 1024)) # 1MB = 1048576 bytes + +if [ "$file_size" -lt "$min_size" ]; then + echo "Error: File $file_path is too small! Current size: $((file_size/1024))KB, Required: ${min_size_mb}MB" + exit 1 +fi + +echo "Check passed: File exists and size is $(($file_size/1024/1024))MB" +exit 0 diff --git a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh similarity index 96% rename from tests/e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh rename to tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh index eceef7826..729b42554 100644 --- a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh +++ b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh @@ -43,14 +43,14 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.param_offload=True \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml" \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ + trainer.logger='["console","wandb"]' \ trainer.project_name='retool_async_rl' \ trainer.experiment_name='qwen3-4b_function_rm-retool-async-sgl-no-sft-n8-v2505271300' \ trainer.val_before_train=False \ diff --git a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh similarity index 98% rename from tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh rename to tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh index 333f6d2bd..76983ddad 100644 --- a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh +++ b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh @@ -44,7 +44,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='gsm8k_async_rl' \ trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \ trainer.n_gpus_per_node=8 \ diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh new file mode 100644 index 000000000..72232d4db --- /dev/null +++ b/tests/special_e2e/run_ppo_trainer_megatron.sh @@ -0,0 +1,237 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping +export VERL_LOGGING_LEVEL=INFO +export VERL_PPO_LOGGING_LEVEL=INFO + +NUM_GPUS=${NUM_GPUS:-8} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +USE_DUMMY_MODEL=${USE_DUMMY_MODEL:-False} +DUMMY_MODEL_PATH=${DUMMY_MODEL_PATH:-${HOME}/dummy_models/${MODEL_ID}} +if [ "$USE_DUMMY_MODEL" = "True" ]; then + if [ -z "${DUMMY_MODEL_CONFIG_PATH}" ]; then + echo "[ERROR] DUMMY_MODEL_CONFIG_PATH not set" + exit 1 + fi + + python scripts/init_random_model.py \ + --hf_model_path "${MODEL_PATH}" \ + --new_config_path "${DUMMY_MODEL_CONFIG_PATH}" \ + --output_path "${DUMMY_MODEL_PATH}" + + MODEL_PATH="${DUMMY_MODEL_PATH}" +fi + +TRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet} +VAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet} + +ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} +# Validation +VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} +TEST_FREQ=${TEST_FREQ:--1} +# Save & Resume +RESUME_MODE=${RESUME_MODE:-disable} +SAVE_FREQ=${SAVE_FREQ:--1} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} + +USE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True} +ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN:-2400} +forward_max_token_len_per_gpu=${FWD_MAX_TOKEN_LEN:-4800} +train_traj_micro_bsz_per_gpu=${MICRO_BSZ:-2} # b +n_resp_per_prompt=4 # g + +train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n +train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n +train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g +train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g + +MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512} +MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512} + +COMMON_PP=${COMMON_PP:-2} +COMMON_VPP=${COMMON_VPP:-2} +COMMON_CP=${COMMON_CP:-2} +COMMON_TP=${COMMON_TP:-2} +COMMON_EP=${COMMON_EP:-1} +COMMON_ETP=${COMMON_ETP:-null} + +TRAIN_TP=${TRAIN_TP:-$COMMON_TP} +INFER_TP=${INFER_TP:-$COMMON_TP} + +ACTOR_PP=${ACTOR_PP:-$COMMON_PP} +ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} +ACTOR_CP=${ACTOR_CP:-$COMMON_CP} +ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} +ACTOR_EP=${ACTOR_EP:-$COMMON_EP} +ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} +ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} +REF_PP=${REF_PP:-$COMMON_PP} +REF_VPP=${REF_VPP:-$COMMON_VPP} +REF_CP=${REF_CP:-$COMMON_CP} +REF_TP=${REF_TP:-$TRAIN_TP} +REF_EP=${REF_EP:-$COMMON_EP} +REF_ETP=${REF_ETP:-$COMMON_ETP} +CRITIC_PP=${CRITIC_PP:-$COMMON_PP} +CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} +CRITIC_CP=${CRITIC_CP:-$COMMON_CP} +CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} +CRITIC_EP=${CRITIC_EP:-$COMMON_EP} +CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} +RM_PP=${RM_PP:-$COMMON_PP} +RM_VPP=${RM_VPP:-$COMMON_VPP} +RM_CP=${RM_CP:-$COMMON_CP} +RM_TP=${RM_TP:-$TRAIN_TP} +RM_EP=${RM_EP:-$COMMON_EP} +RM_ETP=${RM_ETP:-$COMMON_ETP} + +ALL_OFFLOAD=${ALL_OFFLOAD:-False} +COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} +COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} +COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} + +ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +USE_MBRIDGE=${USE_MBRIDGE:-False} +USE_FUSED_KERNELS=${USE_FUSED_KERNELS:-False} + +LR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null} + +CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra'] +SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0} +if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then + CHECKPOINT_CONTENTS=['model','optimizer','extra'] +fi + +USE_DIST_CKPT=${USE_DIST_CKPT:-False} +DIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/${MODEL_ID}} +if [ "$USE_DIST_CKPT" = "True" ]; then + if [ "$USE_DUMMY_MODEL" = "True" ]; then + DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID} + fi + python scripts/converter_hf_to_mcore.py \ + --hf_model_path "${MODEL_PATH}" \ + --output_path "${DIST_CKPT_PATH}" +fi + +ENGINE=${ENGINE:-"vllm"} + +exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" + +if [ "$ENGINE" = "vllm" ]; then + MODE=${MODE:-"sync"} + ROLLOUT_MODE_ARG="actor_rollout_ref.rollout.mode=${MODE}" + if [ "$MODE" = "async" ]; then + ROLLOUT_MODE_ARG="${ROLLOUT_MODE_ARG} data.return_raw_chat=True" + fi +else + ROLLOUT_MODE_ARG="" +fi + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator="${ADV_ESTIMATOR}" \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${VAL_FILES}" \ + data.train_batch_size=${train_prompt_bsz} \ + data.max_prompt_length=${MAX_PROMPT_LENGTH} \ + data.max_response_length=${MAX_RESPONSE_LENGTH} \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_fused_kernels=${USE_FUSED_KERNELS} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \ + actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ + actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$ACTOR_EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ACTOR_ETP \ + actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ + actor_rollout_ref.rollout.name="${ENGINE}" ${ROLLOUT_MODE_ARG}\ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ + actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$REF_EP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$REF_ETP \ + actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ + critic.optim.lr=2e-5 \ + critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \ + critic.model.path="${MODEL_PATH}" \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \ + critic.megatron.use_mbridge=${USE_MBRIDGE} \ + critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ + critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ + critic.megatron.context_parallel_size=$CRITIC_CP \ + critic.megatron.tensor_model_parallel_size=$CRITIC_TP \ + critic.megatron.expert_model_parallel_size=$CRITIC_EP \ + critic.megatron.expert_tensor_parallel_size=$CRITIC_ETP \ + critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \ + critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \ + critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ + critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ + critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ + reward_model.enable=True \ + reward_model.model.path="${MODEL_PATH}" \ + reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ + reward_model.megatron.use_mbridge=${USE_MBRIDGE} \ + reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ + reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ + reward_model.megatron.context_parallel_size=$RM_CP \ + reward_model.megatron.tensor_model_parallel_size=$RM_TP \ + reward_model.megatron.expert_model_parallel_size=$RM_EP \ + reward_model.megatron.expert_tensor_parallel_size=$RM_ETP \ + reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \ + reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl-test' \ + trainer.experiment_name="${exp_name}" \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=${NUM_GPUS} \ + trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ + trainer.test_freq="${TEST_FREQ}" \ + trainer.save_freq="${SAVE_FREQ}" \ + trainer.resume_mode="${RESUME_MODE}" \ + trainer.total_epochs=2 \ + trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ diff --git a/tests/e2e/run_prime.sh b/tests/special_e2e/run_prime.sh similarity index 98% rename from tests/e2e/run_prime.sh rename to tests/special_e2e/run_prime.sh index 0d0a8b50a..ac7ecb79c 100644 --- a/tests/e2e/run_prime.sh +++ b/tests/special_e2e/run_prime.sh @@ -62,7 +62,7 @@ python3 -m recipe.prime.main_prime \ reward_model.mini_batch_size=${train_prompt_bsz} \ reward_model.reward_manager=prime \ trainer.val_before_train=False \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl-test' \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=${NUM_GPUS} \ diff --git a/tests/e2e/run_r1_distill_qwen_aime24_eval.sh b/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh similarity index 96% rename from tests/e2e/run_r1_distill_qwen_aime24_eval.sh rename to tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh index a3aa4ac54..5dec6fe6c 100644 --- a/tests/e2e/run_r1_distill_qwen_aime24_eval.sh +++ b/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh @@ -21,7 +21,7 @@ python3 -m verl.trainer.main_generation \ rollout.gpu_memory_utilization=0.95 \ rollout.max_num_batched_tokens=65536 \ rollout.enforce_eager=False \ - rollout.free_cache_engine=False + rollout.free_cache_engine=True python3 -m recipe.r1.main_eval \ data.path=$HOME/data/r1/test-output-k1.parquet \ diff --git a/tests/e2e/run_spin.sh b/tests/special_e2e/run_spin.sh similarity index 94% rename from tests/e2e/run_spin.sh rename to tests/special_e2e/run_spin.sh index 1b1c628cd..1b5a2af0d 100644 --- a/tests/e2e/run_spin.sh +++ b/tests/special_e2e/run_spin.sh @@ -19,9 +19,8 @@ CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ actor_rollout_ref.ref.log_prob_micro_batch_size=64 \ algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=False \ - trainer.default_hdfs_dir=null \ trainer.n_gpus_per_node=4 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/tests/e2e/run_sppo.sh b/tests/special_e2e/run_sppo.sh similarity index 98% rename from tests/e2e/run_sppo.sh rename to tests/special_e2e/run_sppo.sh index 4781b3ded..a33131972 100644 --- a/tests/e2e/run_sppo.sh +++ b/tests/special_e2e/run_sppo.sh @@ -38,7 +38,7 @@ python3 -m recipe.sppo.main_sppo \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.val_before_train=False \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ diff --git a/tests/e2e/run_test.sh b/tests/special_e2e/run_test.sh similarity index 100% rename from tests/e2e/run_test.sh rename to tests/special_e2e/run_test.sh diff --git a/tests/e2e/sft/run_sft.sh b/tests/special_e2e/sft/run_sft.sh similarity index 95% rename from tests/e2e/sft/run_sft.sh rename to tests/special_e2e/sft/run_sft.sh index 850f8b4a7..4cd9a4790 100644 --- a/tests/e2e/sft/run_sft.sh +++ b/tests/special_e2e/sft/run_sft.sh @@ -50,7 +50,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.total_training_steps=1 \ - trainer.logger=['console'] \ - trainer.default_hdfs_dir=null $@ + trainer.logger=console $@ rm -rf "${ckpts_home:?}/*" \ No newline at end of file diff --git a/tests/e2e/sft/test_sp_loss_match.py b/tests/special_e2e/sft/test_sp_loss_match.py similarity index 93% rename from tests/e2e/sft/test_sp_loss_match.py rename to tests/special_e2e/sft/test_sp_loss_match.py index 71b1abff8..4dc0cbdae 100644 --- a/tests/e2e/sft/test_sp_loss_match.py +++ b/tests/special_e2e/sft/test_sp_loss_match.py @@ -101,7 +101,9 @@ def create_trainer(config): device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + ) # build tokenizer and datasets first from verl.trainer.fsdp_sft_trainer import create_sft_dataset @@ -113,7 +115,14 @@ def create_trainer(config): train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) - return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset) + return FSDPSFTTrainer( + config=config, + device_mesh=device_mesh, + ulysses_device_mesh=ulysses_device_mesh, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) def main(config): diff --git a/tests/e2e/run_qwen_gsm8k_dapo.sh b/tests/special_npu/run_qwen2_5_05b_dapo.sh similarity index 51% rename from tests/e2e/run_qwen_gsm8k_dapo.sh rename to tests/special_npu/run_qwen2_5_05b_dapo.sh index 9c542d032..3f3756bdf 100644 --- a/tests/e2e/run_qwen_gsm8k_dapo.sh +++ b/tests/special_npu/run_qwen2_5_05b_dapo.sh @@ -1,7 +1,11 @@ #!/usr/bin/env bash -set -x +set -xeuo pipefail -export VLLM_ATTENTION_BACKEND=XFORMERS +NUM_GPUS=${NUM_GPUS:-8} + +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" adv_estimator=grpo @@ -13,8 +17,8 @@ kl_loss_coef=0.0 clip_ratio_low=0.2 clip_ratio_high=0.28 -max_prompt_length=512 -max_response_length=512 +max_prompt_length=1024 +max_response_length=2048 enable_overlong_buffer=True overlong_buffer_len=128 overlong_penalty_factor=1.0 @@ -24,14 +28,22 @@ loss_agg_mode="token-mean" enable_filter_groups=True filter_groups_metric=seq_reward max_num_gen_batches=10 -train_prompt_bsz=32 -train_prompt_mini_bsz=$((train_prompt_bsz / 2)) -gen_prompt_bsz=$((train_prompt_bsz * 3)) -n_resp_per_prompt=4 -python3 -m recipe.dapo.src.main_dapo \ - data.train_files="$HOME/data/gsm8k/train.parquet" \ - data.val_files="$HOME/data/gsm8k/test.parquet" \ +train_traj_micro_bsz_per_gpu=2 # b +n_resp_per_prompt=4 # g + +train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n +train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n +train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g +train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g + +gen_prompt_bsz=$((train_prompt_bsz * 4)) + +exp_name="$(basename "${MODEL_ID,,}")-dapo-minimal" + +python3 -m recipe.dapo.main_dapo \ + data.train_files="${HOME}/data/gsm8k/train.parquet" \ + data.val_files="${HOME}/data/gsm8k/test.parquet" \ reward_model.reward_manager=dapo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ @@ -51,24 +63,35 @@ python3 -m recipe.dapo.src.main_dapo \ algorithm.filter_groups.enable=${enable_filter_groups} \ algorithm.filter_groups.metric=${filter_groups_metric} \ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ - trainer.logger=['console'] \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2.5_0.5b_e2e_ci_dapo' \ - trainer.n_gpus_per_node=8 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.entropy_checkpointing=True \ + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + trainer.logger=console \ + trainer.project_name='verl-test' \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=${NUM_GPUS} \ trainer.nnodes=1 \ trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ \ No newline at end of file + trainer.total_epochs=2 \ + trainer.resume_mode=disable \ + trainer.val_before_train=False \ + trainer.total_training_steps=1 \ + trainer.device=npu $@ diff --git a/tests/npu/run_qwen2_5_05b_grpo.sh b/tests/special_npu/run_qwen2_5_05b_grpo.sh similarity index 94% rename from tests/npu/run_qwen2_5_05b_grpo.sh rename to tests/special_npu/run_qwen2_5_05b_grpo.sh index d54102b75..466386b15 100644 --- a/tests/npu/run_qwen2_5_05b_grpo.sh +++ b/tests/special_npu/run_qwen2_5_05b_grpo.sh @@ -1,6 +1,5 @@ set -x -export VLLM_ATTENTION_BACKEND=XFORMERS python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ @@ -32,7 +31,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console'] \ + trainer.logger=console \ trainer.project_name='verl_grpo_example_gsm8k' \ trainer.experiment_name='qwen2_7b_function_rm' \ trainer.n_gpus_per_node=8 \ @@ -41,4 +40,4 @@ python3 -m verl.trainer.main_ppo \ trainer.test_freq=5 \ trainer.total_epochs=1 \ trainer.total_training_steps=2 \ - trainer.device=npu $@ \ No newline at end of file + trainer.device=npu $@ diff --git a/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh b/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh new file mode 100644 index 000000000..1bb8fc4cd --- /dev/null +++ b/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh @@ -0,0 +1,29 @@ +set -x + +mkdir -p ./save_ckpts + +torchrun --standalone --nnodes=1 --nproc_per_node=8 \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=32 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=./save_ckpts \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ + trainer.logger=console \ + trainer.total_epochs=1 \ + trainer.total_training_steps=1 $@ \ + model.lora_rank=32 \ + model.lora_alpha=16 \ + model.target_modules=all-linear \ + model.strategy=fsdp \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true + +rm -rf ./outputs ./save_ckpts diff --git a/tests/special_npu/run_qwen2_5_vl_3b_npu.sh b/tests/special_npu/run_qwen2_5_vl_3b_npu.sh new file mode 100644 index 000000000..dc3799e99 --- /dev/null +++ b/tests/special_npu/run_qwen2_5_vl_3b_npu.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=1 \ + trainer.total_training_steps=1 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/special_sanity/check_api_docs.py b/tests/special_sanity/check_api_docs.py new file mode 100644 index 000000000..fa31ec8c5 --- /dev/null +++ b/tests/special_sanity/check_api_docs.py @@ -0,0 +1,135 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Fail CI if any function or class that is publicly exported via +``__all__`` lacks a docstring. + +Usage +----- + # Check specific modules or packages + python check_docstrings.py mypkg.core mypkg.utils + + # Check an entire source tree (all top-level packages under cwd) + python check_docstrings.py +""" + +from __future__ import annotations + +import argparse +import importlib +import inspect +import pkgutil +import sys +from pathlib import Path +from types import ModuleType +from typing import Iterable + +_ALLOW_LIST = [ + "verl.third_party.vllm.LLM", + "verl.third_party.vllm.parallel_state", + "verl.utils.profiler.WorkerProfiler", + "verl.utils.profiler.WorkerProfilerExtension", + "verl.utils.profiler.log_gpu_memory_usage", + "verl.utils.profiler.log_print", + "verl.utils.profiler.mark_annotate", + "verl.utils.profiler.mark_end_range", + "verl.utils.profiler.mark_start_range", + "verl.models.mcore.qwen2_5_vl.get_vision_model_config", + "verl.models.mcore.qwen2_5_vl.get_vision_projection_config", +] + + +def iter_submodules(root: ModuleType) -> Iterable[ModuleType]: + """Yield *root* and every sub-module inside it.""" + yield root + if getattr(root, "__path__", None): # only packages have __path__ + for mod_info in pkgutil.walk_packages(root.__path__, prefix=f"{root.__name__}."): + try: + yield importlib.import_module(mod_info.name) + except Exception as exc: # noqa: BLE001 + print(f"[warn] Skipping {mod_info.name!r}: {exc}", file=sys.stderr) + + +def names_missing_doc(mod: ModuleType) -> list[str]: + """Return fully-qualified names that need docstrings.""" + missing: list[str] = [] + public = getattr(mod, "__all__", []) + for name in public: + obj = getattr(mod, name, None) + if f"{mod.__name__}.{name}" in _ALLOW_LIST: + continue + if obj is None: + # Exported but not found in the module: flag it anyway. + missing.append(f"{mod.__name__}.{name} (not found)") + continue + + if inspect.isfunction(obj) or inspect.isclass(obj): + doc = inspect.getdoc(obj) + if not doc or not doc.strip(): + missing.append(f"{mod.__name__}.{name}") + return missing + + +def check_module(qualname: str) -> list[str]: + """Import *qualname* and check it (and sub-modules).""" + try: + module = importlib.import_module(qualname) + except ModuleNotFoundError as exc: + print(f"[error] Cannot import '{qualname}': {exc}", file=sys.stderr) + return [qualname] + + missing: list[str] = [] + for submod in iter_submodules(module): + missing.extend(names_missing_doc(submod)) + return missing + + +def autodiscover_packages() -> list[str]: + """Detect top-level packages under CWD when no argument is given.""" + pkgs: list[str] = [] + for p in Path.cwd().iterdir(): + if p.is_dir() and (p / "__init__.py").exists(): + pkgs.append(p.name) + return pkgs + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "modules", + nargs="*", + help="Fully-qualified module or package names (defaults to every top-level package found in CWD).", + ) + args = parser.parse_args() + + targets = args.modules or autodiscover_packages() + if not targets: + raise ValueError("[error] No modules specified and none detected automatically.") + + all_missing: list[str] = [] + for modname in targets: + all_missing.extend(check_module(modname)) + + if all_missing: + print("\nMissing docstrings:") + for name in sorted(all_missing): + print(f" - {name}") + raise ValueError("Missing docstrings detected. Please enhance them with docs accordingly.") + + print("✅ All exported functions/classes have docstrings.") + + +if __name__ == "__main__": + main() diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py new file mode 100644 index 000000000..c8988db55 --- /dev/null +++ b/tests/special_sanity/check_device_api_usage.py @@ -0,0 +1,93 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`. +Search targets include .py files in verl/recipe and verl/verl. +Some files that must contain ".cuda", "cuda" or "nccl" keyword is pre-defined in whitelist below. +""" + +import os +from argparse import ArgumentParser +from pathlib import Path + +# directory or file path must contain keyword ".cuda" or "cuda" +CUDA_KEYWORD_CHECK_WHITELIST = [ + "verl/utils/device.py", + "recipe/prime/prime_ray_trainer.py", # appear in default device_name + "recipe/spin/spin_trainer.py", # appear in default device_name + "recipe/sppo/sppo_ray_trainer.py", # appear in default device_name + "verl/utils/profiler/nvtx_profile.py", # appear in NsightSystemsProfiler + "verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx + "verl/utils/rendezvous/ray_backend.py", # appear in cupy importance + "verl/single_controller/ray/base.py", # appear in default device_name + "verl/trainer/ppo/ray_trainer.py", # appear in default device_name + "verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type + "verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name +] + +# directory or file path must contain keyword "nccl" +NCCL_KEYWORD_CHECK_WHITELIST = [ + "verl/utils/device.py", + "verl/third_party/sglang/parallel_state.py", # appear in default backend +] + +SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST + +SEARCH_KEYWORDS = [".cuda", '"cuda"', '"nccl"'] + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--directory", "-d", required=True, type=str) + args = parser.parse_args() + directory_in_str = args.directory + + pathlist = Path(directory_in_str).glob("**/*.py") + for path in pathlist: + path_in_str = str(path.absolute()) + + # judge whether current path is in pre-defined search whitelist or not. + path_in_whitelist = False + + for sw in SEARCH_WHITELIST: + # for easy debugging in non-linux system + sw = sw.replace("/", os.sep) + if sw in path_in_str: + print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.") + path_in_whitelist = True + break + + if path_in_whitelist: + continue + + with open(path_in_str, encoding="utf-8") as f: + file_content = f.read() + + find_invalid_device_management = False + + for sk in SEARCH_KEYWORDS: + if sk in file_content: + find_invalid_device_management = True + break + + print( + f"[CHECK] File {path_in_str} is detected for device api usage check, check result: " + f"{'success' if not find_invalid_device_management else f'failed, because detect {sk}'}." + ) + + assert not find_invalid_device_management, ( + f'file {path_in_str} contains .cuda/"cuda"/"nccl" usage, please use api in ' + f"verl/utils/device.py directly." + ) diff --git a/tests/special_sanity/check_docs_time_info.py b/tests/special_sanity/check_docs_time_info.py new file mode 100644 index 000000000..a54d1d50a --- /dev/null +++ b/tests/special_sanity/check_docs_time_info.py @@ -0,0 +1,84 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Check that every .md and .rst file under docs/ contains the substring "Last updated", +with an allow-list for exceptions. +""" + +import sys +from pathlib import Path + +# === CONFIGURATION === + +# Relative paths (to docs/) or glob patterns to skip checking +ALLOW_LIST = { + "docs/README.md", # you can list individual files + "docs/legacy/*.rst", # or glob patterns + "docs/index.rst", + "docs/start/install.rst", + "docs/start/quickstart.rst", + "docs/README_vllm0.7.md", +} + +# The folder to scan +DOCS_DIR = Path("docs") + +# === SCRIPT === + + +def is_allowed(path: Path) -> bool: + """ + Return True if `path` matches any entry in ALLOW_LIST. + """ + rel = str(path) + for pattern in ALLOW_LIST: + if Path(rel).match(pattern): + return True + return False + + +def main(): + if not DOCS_DIR.exists(): + print(f"Error: Documentation directory '{DOCS_DIR}' does not exist.", file=sys.stderr) + sys.exit(1) + + missing = [] + + # Gather all .md and .rst files under docs/ + for ext in ("*.md", "*.rst"): + for path in DOCS_DIR.rglob(ext): + if is_allowed(path): + continue + + text = path.read_text(encoding="utf-8", errors="ignore") + if "Last updated" not in text: + missing.append(path) + + # Report + if missing: + print("\nThe following files are missing the 'Last updated' string:\n") + for p in missing: + print(f" - {p}") + print(f"\nTotal missing: {len(missing)}\n", file=sys.stderr) + raise AssertionError( + "Some documentation files lack a 'Last updated' line. Please include info such as " + "'Last updated: mm/dd/yyyy' to indicate the last update time of the document." + ) + else: + print("✅ All checked files contain 'Last updated'.") + + +if __name__ == "__main__": + main() diff --git a/tests/special_sanity/check_docstrings.py b/tests/special_sanity/check_docstrings.py new file mode 100644 index 000000000..7c5d8ed71 --- /dev/null +++ b/tests/special_sanity/check_docstrings.py @@ -0,0 +1,156 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Python script to check docstrings for functions and classes in specified files. +Checks that every public function and class has proper docstring documentation. +""" + +import ast +import os +import sys + + +class DocstringChecker(ast.NodeVisitor): + """AST visitor to check for missing docstrings in functions and classes.""" + + def __init__(self, filename: str): + self.filename = filename + self.missing_docstrings: list[tuple[str, str, int]] = [] + self.current_class = None + self.function_nesting_level = 0 + + def visit_FunctionDef(self, node: ast.FunctionDef): + """Visit function definitions and check for docstrings.""" + if not node.name.startswith("_") and self.function_nesting_level == 0: + if not self._has_docstring(node): + func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + self.missing_docstrings.append((func_name, self.filename, node.lineno)) + + self.function_nesting_level += 1 + self.generic_visit(node) + self.function_nesting_level -= 1 + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + """Visit async function definitions and check for docstrings.""" + if not node.name.startswith("_") and self.function_nesting_level == 0: + if not self._has_docstring(node): + func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + self.missing_docstrings.append((func_name, self.filename, node.lineno)) + + self.function_nesting_level += 1 + self.generic_visit(node) + self.function_nesting_level -= 1 + + def visit_ClassDef(self, node: ast.ClassDef): + """Visit class definitions and check for docstrings.""" + if not node.name.startswith("_"): + if not self._has_docstring(node): + self.missing_docstrings.append((node.name, self.filename, node.lineno)) + + old_class = self.current_class + self.current_class = node.name + self.generic_visit(node) + self.current_class = old_class + + def _has_docstring(self, node) -> bool: + """Check if a node has a docstring.""" + return ast.get_docstring(node) is not None + + +def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]: + """Check docstrings in a single file.""" + try: + with open(filepath, encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=filepath) + checker = DocstringChecker(filepath) + checker.visit(tree) + return checker.missing_docstrings + + except Exception as e: + print(f"Error processing {filepath}: {e}") + return [] + + +def main(): + """Main function to check docstrings in specified files.""" + + files_to_check = [ + "verl/trainer/ppo/ray_trainer.py", + "verl/trainer/main_ppo.py", + "verl/trainer/ppo/reward.py", + "verl/utils/reward_score/__init__.py", + "verl/trainer/ppo/core_algos.py", + "verl/experimental/agent_loop/agent_loop.py", + "verl/workers/sharding_manager/fsdp_vllm.py", + "verl/workers/sharding_manager/fsdp_ulysses.py", + ] + + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_path = os.path.dirname(os.path.dirname(script_dir)) + + if not os.path.exists(repo_path): + print(f"Repository path {repo_path} does not exist!") + sys.exit(1) + + os.chdir(repo_path) + + all_missing_docstrings = [] + + print("Checking docstrings in specified files...") + print("=" * 60) + + for file_path in files_to_check: + if not os.path.exists(file_path): + print(f"Warning: File {file_path} does not exist!") + continue + + print(f"Checking {file_path}...") + missing = check_file_docstrings(file_path) + all_missing_docstrings.extend(missing) + + if missing: + print(f" Found {len(missing)} missing docstrings") + else: + print(" All functions and classes have docstrings ✓") + + print("=" * 60) + + if all_missing_docstrings: + print(f"\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:") + print("-" * 60) + + by_file = {} + for name, filepath, lineno in all_missing_docstrings: + if filepath not in by_file: + by_file[filepath] = [] + by_file[filepath].append((name, lineno)) + + for filepath in sorted(by_file.keys()): + print(f"\n{filepath}:") + for name, lineno in sorted(by_file[filepath], key=lambda x: x[1]): + print(f" - {name} (line {lineno})") + + print(f"\nTotal missing docstrings: {len(all_missing_docstrings)}") + + raise Exception(f"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!") + + else: + print("\n✅ All functions and classes have proper docstrings!") + + +if __name__ == "__main__": + main() diff --git a/tests/sanity/check_license.py b/tests/special_sanity/check_license.py similarity index 95% rename from tests/sanity/check_license.py rename to tests/special_sanity/check_license.py index 9cf2da3ce..a02afeb3d 100644 --- a/tests/sanity/check_license.py +++ b/tests/special_sanity/check_license.py @@ -21,6 +21,7 @@ license_head_individual = "Copyright 2025 Individual Contributor:" license_head_sglang = "Copyright 2023-2024 SGLang Team" license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates" +license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates" license_headers = [ license_head_bytedance, license_head_bytedance_25, @@ -28,6 +29,7 @@ license_head_individual, license_head_sglang, license_head_modelbest, + license_head_amazon, ] diff --git a/tests/special_sanity/check_pr_description.py b/tests/special_sanity/check_pr_description.py new file mode 100644 index 000000000..4ed4563db --- /dev/null +++ b/tests/special_sanity/check_pr_description.py @@ -0,0 +1,97 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +import json +import os + +# Number of lines to check +NUM_LINES = 5 + + +# Custom exception types for clear error handling +class TemplateFileError(Exception): + pass + + +class PRBodyLoadError(Exception): + pass + + +class PRDescriptionError(Exception): + pass + + +# Path to the PR template file +template_file = os.path.join(os.getenv("GITHUB_WORKSPACE", "."), ".github", "PULL_REQUEST_TEMPLATE.md") + + +def load_template(path): + """ + Load only the first NUM_LINES of the PR template file as a list of lines, + without stripping any characters. + """ + lines = [] + try: + with open(path, encoding="utf-8") as f: + for _ in range(NUM_LINES): + line = f.readline() + if not line: + break + lines.append(line.strip()) + return lines + except Exception as e: + raise TemplateFileError(f"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}") from e + + +def load_pr_body(event_path): + try: + with open(event_path, encoding="utf-8") as f: + payload = json.load(f) + return payload.get("pull_request", {}).get("body", "") or "" + except Exception as e: + raise PRBodyLoadError(f"Failed to read PR body from {event_path}: {e}") from e + + +def check_pr_description(body, template_lines): + """ + Compare the first NUM_LINES lines of the PR body to the template lines. + If they match exactly, the placeholder was not modified. + """ + pr_lines = body.splitlines(keepends=True) + pr_first = [x.strip() for x in pr_lines[:NUM_LINES]] + if pr_first == template_lines: + raise PRDescriptionError( + "It looks like you haven't updated the '### What does this PR do?' section. Please replace " + "the placeholder text with a concise description of what your PR does." + ) + else: + print(pr_first) + print(template_lines) + + +def main(): + event_path = os.getenv("GITHUB_EVENT_PATH") + if not event_path: + raise OSError("GITHUB_EVENT_PATH is not set.") + + template_lines = load_template(template_file) + pr_body = load_pr_body(event_path) + check_pr_description(pr_body, template_lines) + + print("✅ '### What does this PR do?' section has been filled out.") + + +if __name__ == "__main__": + main() diff --git a/tests/special_sanity/check_pr_title.py b/tests/special_sanity/check_pr_title.py new file mode 100644 index 000000000..f4cbd5238 --- /dev/null +++ b/tests/special_sanity/check_pr_title.py @@ -0,0 +1,67 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +# Get PR title from environment +pr_title = os.environ.get("PR_TITLE", "").strip() + +# Define rules +allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"] +allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] +allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] +allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg"] +allowed_types = ["feat", "fix", "refactor", "chore", "test"] + +# Check for [BREAKING] prefix and extract the rest of the title +breaking_match = re.match(r"^\[BREAKING\]\s*(.+)$", pr_title, re.IGNORECASE) +if breaking_match: + core_pr_title = breaking_match.group(1).strip() + is_breaking = True +else: + core_pr_title = pr_title + is_breaking = False + +# Build dynamic regex pattern for modules (now working on core_pr_title) +re_modules_pattern = re.compile(r"^\[([a-z_,\s]+)\]", re.IGNORECASE) +re_modules = re_modules_pattern.match(core_pr_title) +if not re_modules: + print(f"❌ Invalid PR title: '{pr_title}'") + print("Expected format: [BREAKING][module] type: description") + print(f"Allowed modules: {', '.join(allowed_modules)}") + raise Exception("Invalid PR title") +else: + modules = re.findall(r"[a-z_]+", re_modules.group(1).lower()) + if not all(module in allowed_modules for module in modules): + invalid_modules = [module for module in modules if module not in allowed_modules] + print(f"❌ Invalid modules: {', '.join(invalid_modules)}") + print(f"Allowed modules: {', '.join(allowed_modules)}") + raise Exception("Invalid PR title") + +types_pattern = "|".join(re.escape(t) for t in allowed_types) +re_types_pattern = re.compile(rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE) +match = re_types_pattern.match(core_pr_title) + +if not match: + print(f"❌ Invalid PR title: '{pr_title}'") + print("Expected format: [BREAKING][module] type: description") + print(f"Allowed types: {', '.join(allowed_types)}") + raise Exception("Invalid PR title") + +change_type = match.group(1).lower() + +# Build the success message +breaking_info = " (BREAKING CHANGE)" if is_breaking else "" +print(f"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}") diff --git a/tests/special_sanity/test_config_docs.py b/tests/special_sanity/test_config_docs.py new file mode 100644 index 000000000..2f260f10b --- /dev/null +++ b/tests/special_sanity/test_config_docs.py @@ -0,0 +1,86 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from pathlib import Path + + +def validate_yaml_format(yaml_lines): + errors = [] + i = 0 + + while i < len(yaml_lines): + line = yaml_lines[i] + stripped = line.strip() + + # Skip empty lines + if stripped == "": + i += 1 + continue + + # Match YAML keys like "field:" or "field: value" + key_match = re.match(r"^(\s*)([a-zA-Z0-9_]+):", line) + if key_match: + # Check if there's a comment above + if i == 0 or not yaml_lines[i - 1].strip().startswith("#"): + errors.append(f"Missing comment above line {i + 1}: {line.strip()}") + + # Check for inline comment + if "#" in line and not stripped.startswith("#"): + comment_index = line.index("#") + colon_index = line.index(":") + if comment_index > colon_index: + errors.append(f"Inline comment found on line {i + 1}: {line.strip()}") + + # Check for blank line after this key line (unless next is a deeper indent) + if i + 1 < len(yaml_lines): + next_line = yaml_lines[i + 1] + next_stripped = next_line.strip() + + # If next is not empty and not a deeper nested line, enforce blank line + if next_stripped != "": + errors.append(f"Missing blank line after line {i + 1}: {line.strip()}") + + i += 1 + + return errors + + +def test_trainer_config_doc(): + yamls_to_inspect = [ + "verl/trainer/config/ppo_trainer.yaml", + "verl/trainer/config/actor/actor.yaml", + "verl/trainer/config/actor/dp_actor.yaml", + "verl/trainer/config/ref/ref.yaml", + "verl/trainer/config/ref/dp_ref.yaml", + "verl/trainer/config/rollout/rollout.yaml", + ] + success = True + for yaml_to_inspect in yamls_to_inspect: + yaml_path = Path(yaml_to_inspect) # path to your YAML file + with open(yaml_path) as f: + lines = f.readlines() + + validation_errors = validate_yaml_format(lines) + if validation_errors: + success = False + print("YAML documentation format check failed:") + print(f"Please read the top block of {yaml_to_inspect} to see format rules:\n") + for err in validation_errors: + print(" -", err) + + if not success: + raise Exception("Please fix documentation format.") + else: + print("YAML format check passed ✅") diff --git a/tests/sanity/test_import.py b/tests/special_sanity/test_import.py similarity index 100% rename from tests/sanity/test_import.py rename to tests/special_sanity/test_import.py diff --git a/tests/special_sanity/type_coverage_check.py b/tests/special_sanity/type_coverage_check.py new file mode 100644 index 000000000..dc6dc7caf --- /dev/null +++ b/tests/special_sanity/type_coverage_check.py @@ -0,0 +1,180 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom type annotation check tool. +To inspect the type annotation for functions in the entire codebase, please run: +find verl -type f -name "*.py" | xargs -n 1 python3 tests/special_sanity/type_coverage_check.py --all-lines +--debug --target-file +""" + +import argparse +import ast +import linecache +import subprocess +from pathlib import Path + + +def get_changed_files() -> list[Path]: + result = subprocess.run( + ["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], stdout=subprocess.PIPE, text=True + ) + return [Path(f) for f in result.stdout.splitlines() if f.endswith(".py")] + + +def get_changed_lines(file_path: Path) -> set[int]: + result = subprocess.run( + ["git", "diff", "-U0", "origin/main...HEAD", "--", str(file_path)], + stdout=subprocess.PIPE, + text=True, + ) + lines: set[int] = set() + for line in result.stdout.splitlines(): + if line.startswith("@@"): + for part in line.split(): + try: + if part.startswith("+") and "," in part: + start, count = map(int, part[1:].split(",")) + lines.update(range(start, start + count)) + elif part.startswith("+") and "," not in part: + lines.add(int(part[1:])) + except Exception: + # (vermouth1992) There are many edge cases here because + can be in the changed program + pass + return lines + + +CHECK_SUCCESS = 0 +CHECK_WARNING = 1 +CHECK_FAILURE = -1 + + +def should_check_type(arg_name: str) -> bool: + if arg_name in ("self", "cls"): + return False + if arg_name.startswith("*"): + return False + return True + + +def has_type_annotations(node: ast.AST, debug: bool = False) -> int: + if isinstance(node, ast.FunctionDef): + is_private = node.name.startswith("_") + has_ann = ( + all(arg.annotation is not None for arg in node.args.args if should_check_type(arg.arg)) + and node.returns is not None + ) + if has_ann or is_private: + return CHECK_SUCCESS + else: + if debug: + print(node, [(arg.annotation, arg.arg) for arg in node.args.args if should_check_type(arg.arg)]) + return CHECK_FAILURE + return CHECK_SUCCESS + + +def check_file( + file_path: Path, changed_lines: set[int], debug: bool = False +) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]: + with open(file_path) as f: + source: str = f.read() + tree = ast.parse(source, filename=str(file_path)) + annotated = 0 + total = 0 + warning_lines: list[tuple[Path, int, str]] = [] + failure_lines: list[tuple[Path, int, str]] = [] + + for node in ast.walk(tree): + if hasattr(node, "lineno") and node.lineno in changed_lines: + if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign): + total += 1 + result = has_type_annotations(node, debug) + if result == CHECK_SUCCESS or result == CHECK_WARNING: + annotated += 1 + if result == CHECK_WARNING: + warning_lines.append( + (file_path, node.lineno, linecache.getline(str(file_path), node.lineno).strip()) + ) + else: + source_line = linecache.getline(str(file_path), node.lineno).strip() + failure_lines.append((file_path, node.lineno, source_line)) + + return annotated, total, warning_lines, failure_lines + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--threshold", type=float, default=0.3, help="Minimum ratio of annotated lines required (0.0 - 1.0)" + ) + parser.add_argument("--target-file", type=str, default=None, help="Path to the Python source file to analyse") + parser.add_argument( + "--all-lines", + action="store_true", + help="Check all lines in the file instead of only changed lines based on git", + ) + parser.add_argument("--debug", action="store_true", help="Add debugging logs") + args = parser.parse_args() + + total_changed = 0 + total_annotated = 0 + all_warnings: list[tuple[Path, int, str]] = [] + all_failures: list[tuple[Path, int, str]] = [] + + target_files = [args.target_file] if args.target_file is not None else get_changed_files() + for fpath in target_files: + if "tests/" in str(fpath): + continue + if args.all_lines: + changed_lines = [i + 1 for i in range(len(open(fpath).readlines()))] + else: + changed_lines = get_changed_lines(fpath) + annotated, total, warning_lines, failure_lines = check_file(fpath, changed_lines, args.debug) + total_annotated += annotated + total_changed += total + all_warnings.extend(warning_lines) + all_failures.extend(failure_lines) + + ratio = (total_annotated / total_changed) if total_changed else 1.0 + + print( + f"🔍 Type coverage on {'all' if args.all_lines else 'changed'} lines: " + f"{total_annotated}/{total_changed} = {ratio:.2%}. Files inspected: {target_files}" + ) + + if all_warnings: + print("\n⚠️ Suggest Improve: Lines missing type annotations for inputs and outputs:\n") + for fname, lineno, line in all_warnings: + print(f"{fname}:{lineno}: {line}") + + if all_failures: + print("⚠️ [ERROR] Lines missing type annotations for inputs and outputs:\n") + for fname, lineno, line in all_failures: + print(f"{fname}:{lineno}: {line}") + + if ratio < args.threshold: + print( + f"Please add type annotations for inputs and outputs to meet threshold {args.threshold}. " + f"Cases exempt from checking:" + ) + print("1. Private methods.") + print("2. Args with name in ('self', 'cls'), or *args / **kwargs") + print("3. Files under tests/") + raise Exception(f"\n❌ Type coverage below threshold ({args.threshold:.0%}).") + else: + if all_warnings or all_failures: + print("") + print("✅ Type annotation coverage acceptable.\n") + + +if __name__ == "__main__": + main() diff --git a/tests/special_sanity/validate_imported_docs.py b/tests/special_sanity/validate_imported_docs.py new file mode 100644 index 000000000..b36a407be --- /dev/null +++ b/tests/special_sanity/validate_imported_docs.py @@ -0,0 +1,130 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +verify_imported_docs.py + +Assert that every function or class *explicitly imported* (via +`from import `) in a given Python file has a docstring. +""" + +from __future__ import annotations + +import argparse +import ast +import importlib +import inspect +import pathlib +import sys + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Verify that imported functions/classes have docstrings.") + p.add_argument( + "--target-file", + default="verl/trainer/ppo/ray_trainer.py", + help="Path to the Python source file to analyse (e.g. verl/trainer/ppo/ray_trainer.py)", + ) + p.add_argument( + "--allow-list", + default=["omegaconf.open_dict"], + help="a list of third_party dependencies that do not have proper docs :(", + ) + p.add_argument( + "--project-root", + default=".", + help="Directory to prepend to PYTHONPATH so local packages resolve (default: .)", + ) + p.add_argument( + "--quiet", + action="store_true", + help="Suppress success message (still prints errors).", + ) + return p.parse_args() + + +def _import_attr(module_name: str, attr_name: str): + """Import `module_name` then return `getattr(module, attr_name)`.""" + module = importlib.import_module(module_name) + return getattr(module, attr_name) + + +def _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str]) -> list[str]: + """Return a list of error strings (empty == success).""" + # Ensure local packages resolve + sys.path.insert(0, str(project_root.resolve())) + + tree = ast.parse(py_file.read_text(), filename=str(py_file)) + problems: list[str] = [] + + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom): + continue + + # Relative imports (level > 0) get the leading dots stripped + module_name = "." * node.level + (node.module or "") + for alias in node.names: + if alias.name == "*": + problems.append( + f"{py_file}:{node.lineno} - wildcard import `from {module_name} import *` cannot be verified." + ) + continue + + imported_name = alias.name + + try: + obj = _import_attr(module_name, imported_name) + except Exception: # pragma: no cover – wide net for import quirks + pass + # For some reason the module cannot be imported, skip for now + # problems.append( + # f"{py_file}:{node.lineno} - could not resolve " + # f"`{imported_name}` from `{module_name}` ({exc})" + # ) + continue + + if f"{module_name}.{imported_name}" in allow_list: + continue + if inspect.isfunction(obj) or inspect.isclass(obj): + doc = inspect.getdoc(obj) + if not (doc and doc.strip()): + kind = "class" if inspect.isclass(obj) else "function" + problems.append( + f"{py_file}:{node.lineno} - {kind} `{module_name}.{imported_name}` is missing a docstring." + ) + + return problems + + +def main() -> None: + args = _parse_args() + target_path = pathlib.Path(args.target_file).resolve() + project_root = pathlib.Path(args.project_root).resolve() + + if not target_path.is_file(): + raise Exception(f"❌ Target file not found: {target_path}") + + errors = _check_file(target_path, project_root, args.allow_list) + + if errors: + print("Docstring verification failed:\n") + print("\n".join(f" • {e}" for e in errors)) + raise Exception("❌ Docstring verification failed.") + + if not args.quiet: + print(f"✅ All explicitly imported functions/classes in {target_path} have docstrings.") + + +if __name__ == "__main__": + main() diff --git a/tests/special_sanity/validate_structure.py b/tests/special_sanity/validate_structure.py new file mode 100644 index 000000000..a5390b15a --- /dev/null +++ b/tests/special_sanity/validate_structure.py @@ -0,0 +1,118 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +""" +Validate that test file subfolders mirror the top-level package layout. + +Usage examples +-------------- + +# Typical run (defaults: impl_root=my_project, tests_root=tests) +python check_tests_structure.py + +# Custom layout and extra allowed folders +python check_tests_structure.py \ + --impl-root verl \ + --tests-root tests \ + --allow-dirs special_e2e special_sanity special_standalone special_distributed +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + + +def discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[str]: + """Return the set of first-level directories that tests may live under.""" + allowed = {p.name for p in impl_root.iterdir() if p.is_dir()} + allowed.update(extra) + return allowed + + +def find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str]) -> list[str]: + """Return a list of error strings for test files in the wrong place.""" + errors: list[str] = [] + for test_file in tests_root.rglob("test*.py"): + if str(test_file) in allowed_files: + continue + rel_parts = test_file.relative_to(tests_root).parts + if len(rel_parts) < 2: + errors.append(f"{test_file}: must be inside one of {sorted(allowed)} (not at tests root)") + continue + + first_folder = rel_parts[0] + if first_folder not in allowed: + errors.append( + f"{test_file}: subfolder '{first_folder}' under tests/ is not an allowed module. " + f"The valid ones are: {sorted(allowed)}" + ) + return errors + + +def main() -> None: + parser = argparse.ArgumentParser(description="Check that test files follow tests//… layout.") + parser.add_argument( + "--impl-root", + type=Path, + default="verl", + help="Implementation root (default: my_project)", + ) + parser.add_argument( + "--tests-root", + type=Path, + default="tests", + help="Root of test tree (default: tests)", + ) + parser.add_argument( + "--allow-dirs", + nargs="*", + default=["special_e2e", "special_sanity", "special_standalone", "special_distributed"], + help="Extra top-level test folders that are exempt from the rule", + ) + parser.add_argument( + "--allow-files", + nargs="*", + default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"], + help="Extra top-level test folders that are exempt from the rule", + ) + args = parser.parse_args() + + if not args.impl_root.is_dir(): + raise Exception(f"Implementation root '{args.impl_root}' does not exist.") + if not args.tests_root.is_dir(): + raise Exception(f"Tests root '{args.tests_root}' does not exist.") + + allowed = discover_allowed_modules(args.impl_root, args.allow_dirs) + violations = find_violations(args.tests_root, allowed, args.allow_files) + + if violations: + print("❌ Test layout violations found:\n", file=sys.stderr) + for err in violations: + print(" -", err, file=sys.stderr) + + print( + f"\nGuideline:\n Place each test file under tests//…\n where is " + f"one of the top-level packages inside '{args.impl_root}', or is explicitly listed via --allow-dirs.\n", + file=sys.stderr, + ) + raise Exception("❌ Test layout violations found.") + + print("✅ Tests folder structure looks good.") + + +if __name__ == "__main__": + main() diff --git a/tests/special_standalone/README.md b/tests/special_standalone/README.md new file mode 100644 index 000000000..0e3596e1a --- /dev/null +++ b/tests/special_standalone/README.md @@ -0,0 +1 @@ +The standalone test folder is reserved for tests that require dedicated environment (e.g. memory stress tests) diff --git a/tests/gpu_utility/test_memory_buffers.py b/tests/special_standalone/test_memory_buffers.py similarity index 97% rename from tests/gpu_utility/test_memory_buffers.py rename to tests/special_standalone/test_memory_buffers.py index 6f34314be..778515347 100644 --- a/tests/gpu_utility/test_memory_buffers.py +++ b/tests/special_standalone/test_memory_buffers.py @@ -57,7 +57,7 @@ def test_memory_buffers(): change_ratio = (a - a_before) / a_before assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}" - for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()): + for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True): assert name1 == name2 assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}" diff --git a/tests/test_base_config_on_cpu.py b/tests/test_base_config_on_cpu.py new file mode 100644 index 000000000..9a50235c8 --- /dev/null +++ b/tests/test_base_config_on_cpu.py @@ -0,0 +1,42 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from verl.base_config import BaseConfig + + +@pytest.fixture +def base_config_mock(): + """Fixture to create a mock BaseConfig instance with test attributes.""" + mock_config = BaseConfig() + mock_config.test_attr = "test_value" + return mock_config + + +def test_getitem_success(base_config_mock): + """Test __getitem__ with existing attribute (happy path).""" + assert base_config_mock["test_attr"] == "test_value" + + +def test_getitem_nonexistent_attribute(base_config_mock): + """Test __getitem__ with non-existent attribute (exception path 1).""" + with pytest.raises(AttributeError): + _ = base_config_mock["nonexistent_attr"] + + +def test_getitem_invalid_key_type(base_config_mock): + """Test __getitem__ with invalid key type (exception path 2).""" + with pytest.raises(TypeError): + _ = base_config_mock[123] # type: ignore diff --git a/tests/test_protocol.py b/tests/test_protocol_on_cpu.py similarity index 94% rename from tests/test_protocol.py rename to tests/test_protocol_on_cpu.py index 4cbd51030..2052635c1 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol_on_cpu.py @@ -29,7 +29,9 @@ def test_union_tensor_dict(): data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100]) data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]) - data_with_copied_obs = TensorDict({"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]) + data_with_copied_obs = TensorDict( + {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100] + ) data = union_tensor_dict(data1, data2) with pytest.raises(AssertionError): @@ -77,7 +79,7 @@ def test_tensor_dict_make_iterator(): for data in data_iter_2: data_list_2.append(data) - for data1, data2 in zip(data_list_1, data_list_2): + for data1, data2 in zip(data_list_1, data_list_2, strict=True): assert isinstance(data1, DataProto) assert isinstance(data2, DataProto) result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"])) @@ -416,7 +418,9 @@ def test_dataproto_unfold_column_chunks(): obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}) + data = DataProto.from_dict( + tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + ) ret = data.unfold_column_chunks(2, split_keys=["obs1"]) expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) @@ -431,7 +435,9 @@ def test_dataproto_unfold_column_chunks(): obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]]) labels = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]] - data = DataProto.from_dict(tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}) + data = DataProto.from_dict( + tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + ) ret = data.unfold_column_chunks(2, split_keys=["obs1", "labels"]) expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) @@ -442,21 +448,37 @@ def test_dataproto_unfold_column_chunks(): assert (ret.non_tensor_batch["labels"] == expect_labels).all() assert ret.meta_info == {"name": "abc"} - obs1 = torch.tensor([[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]]) + obs1 = torch.tensor( + [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]] + ) obs2 = torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}) + data = DataProto.from_dict( + tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + ) ret = data.unfold_column_chunks(2, split_keys=["obs1"]) - expect_obs1 = torch.tensor([[[1, 1], [2, 2]], [[3, 3], [4, 4]], [[5, 5], [6, 6]], [[7, 7], [8, 8]], [[9, 9], [10, 10]], [[11, 11], [12, 12]]]) - expect_obs2 = torch.tensor([[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]]) + expect_obs1 = torch.tensor( + [ + [[1, 1], [2, 2]], + [[3, 3], [4, 4]], + [[5, 5], [6, 6]], + [[7, 7], [8, 8]], + [[9, 9], [10, 10]], + [[11, 11], [12, 12]], + ] + ) + expect_obs2 = torch.tensor( + [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]] + ) expect_labels = ["a", "a", "b", "b", "c", "c"] assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1)) assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2)) assert (ret.non_tensor_batch["labels"] == expect_labels).all() assert ret.meta_info == {"name": "abc"} + def test_dataproto_chunk_after_index(): data_len = 4 obs = torch.randn(data_len, 4) @@ -468,33 +490,33 @@ def test_dataproto_chunk_after_index(): selected = data[bool_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) # int or List[int] - + # Test with integer numpy array int_mask = np.array([0, 2]) selected = data[int_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) - + # Test with boolean list list_mask = [True, False, True, False] selected = data[list_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) - + # Test with list list_mask = [0, 2] selected = data[list_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) - + # Test with torch tensor (bool) torch_bool_mask = torch.tensor([True, False, True, False]) selected = data[torch_bool_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) - + # Test with torch tensor (int) torch_int_mask = torch.tensor([0, 2]) selected = data[torch_int_mask] assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) \ No newline at end of file + assert all(isinstance(d, int) for d in selected.batch.batch_size) diff --git a/tests/tools/test_base_tool_on_cpu.py b/tests/tools/test_base_tool_on_cpu.py new file mode 100644 index 000000000..63a2bbb37 --- /dev/null +++ b/tests/tools/test_base_tool_on_cpu.py @@ -0,0 +1,160 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Unit Tests for `initialize_tools_from_config` +import json +import os +from typing import Any + +import pytest +from transformers.utils import get_json_schema + +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.tools.utils.tool_registry import initialize_tools_from_config + + +class WeatherToolForTest(BaseTool): + def get_current_temperature(self, location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_current_temperature) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_current_temperature(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +class WeatherToolWithDataForTest(BaseTool): + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_temperature_date) + return OpenAIFunctionToolSchema(**schema) + + def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_temperature_date(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +@pytest.fixture +def create_local_tool_config(): + tool_config = { + "tools": [ + { + "class_name": "tests.tools.test_base_tool_on_cpu.WeatherToolForTest", + "config": {"type": "native"}, + }, + { + "class_name": "tests.tools.test_base_tool_on_cpu.WeatherToolWithDataForTest", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + yield tool_config_path + if os.path.exists(tool_config_path): + os.remove(tool_config_path) + + +@pytest.fixture +def create_fake_tool_config(): + tool_config = { + "tools": [ + { + "class_name": "tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherTool", + "config": {"type": "native"}, + }, + { + "class_name": "tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherToolWithData", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + yield tool_config_path + if os.path.exists(tool_config_path): + os.remove(tool_config_path) + + +def test_initialize_tools_from_fake_config(create_fake_tool_config): + tool_config_path = create_fake_tool_config + + # Use pytest.raises to check if an exception is raised when calling initialize_tools_from_config. + # Since the tool configuration uses fake paths, an exception is expected during the tool initialization process. + with pytest.raises(ModuleNotFoundError): + _ = initialize_tools_from_config(tool_config_path) + + +def test_initialize_tools_from_local_config(create_local_tool_config): + """ + Test the `initialize_tools_from_config` function using a local tool configuration. + This test verifies that the function can correctly initialize tools based on a local configuration file. + + Args: + create_local_tool_config: A pytest fixture that creates a local tool configuration file + and returns its path. After the test is completed, the fixture + will clean up the configuration file. + """ + # Retrieve the path of the local tool configuration file generated by the fixture + tool_config_path = create_local_tool_config + + tools = initialize_tools_from_config(tool_config_path) + + assert len(tools) == 2 + from tests.tools.test_base_tool_on_cpu import WeatherToolForTest, WeatherToolWithDataForTest + + assert isinstance(tools[0], WeatherToolForTest) + assert isinstance(tools[1], WeatherToolWithDataForTest) + assert tools[0].config == {"type": "native"} + assert tools[1].config == {"type": "native"} diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py index 4c6063623..6f79d474d 100644 --- a/tests/trainer/__init__.py +++ b/tests/trainer/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. """ Tests for the trainer module. -""" \ No newline at end of file +""" diff --git a/verl/third_party/vllm/vllm_v_0_6_3/__init__.py b/tests/trainer/config/__init__.py similarity index 100% rename from verl/third_party/vllm/vllm_v_0_6_3/__init__.py rename to tests/trainer/config/__init__.py diff --git a/verl/recipe/dapo/src/config/dapo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml similarity index 50% rename from verl/recipe/dapo/src/config/dapo_megatron_trainer.yaml rename to tests/trainer/config/legacy_ppo_megatron_trainer.yaml index b9cc1c2ae..fc146c934 100644 --- a/verl/recipe/dapo/src/config/dapo_megatron_trainer.yaml +++ b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml @@ -6,29 +6,37 @@ data: reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 - gen_batch_size: ${data.train_batch_size} train_batch_size: 1024 val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False return_full_prompt: False shuffle: True - filter_overlong_prompts: True # By Reasoning360. Originally False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. filter_overlong_prompts_workers: 1 truncation: error - image_key: images + trust_remote_code: False # main_ppo will check this config to determine whether to use remote code for tokenizer custom_cls: path: null name: null + sampler: + class_path: null + class_name: null + dataloader_num_workers: 8 + return_multi_modal_inputs: True actor_rollout_ref: hybrid_engine: True + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron model: path: ~/models/deepseek-llm-7b-chat + custom_chat_template: null external_lib: null - override_config: { } - enable_gradient_checkpointing: True - use_remove_padding: False + override_config: + model_config: {} + moe_config: + freeze_moe_router: False + enable_gradient_checkpointing: False gradient_checkpointing_kwargs: ## Activation Checkpointing activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' @@ -45,31 +53,44 @@ actor_rollout_ref: ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} - grad_clip: 1.0 + use_torch_compile: True # False to disable torch compile # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified clip_ratio_low: 0.2 clip_ratio_high: 0.2 clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 - data_loader_seed: null loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" # NOTE: "token-mean" is the default behavior - entropy_coeff: 0.001 + entropy_coeff: 0 use_kl_loss: False # True for GRPO - use_torch_compile: True # False to disable torch compile kl_loss_coef: 0.001 # for grpo kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 + data_loader_seed: null shuffle: False + policy_loss: # policy loss config + loss_mode: "vanilla" # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617, + clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss + clip_cov_lb: 1.0 # Lower bound for clip-cov loss + clip_cov_ub: 5.0 # Upper bound for clip-cov loss + kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss + ppo_kl_coef: 0.1 # KL divergence penalty coefficient optim: + optimizer: adam lr: 1e-6 clip_grad: 1.0 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 + lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + lr_decay_steps: null + lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root + min_lr: 0.0 # minimum learning rate, default to 0.0 weight_decay: 0.01 + weight_decay_incr_style: constant # select from constant/linear/cosine + lr_wsd_decay_style: exponential # select from constant/exponential/cosine + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler megatron: param_offload: False grad_offload: False @@ -86,17 +107,23 @@ actor_rollout_ref: dist_checkpointing_path: null seed: 42 override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage - profile: # profile the actor model in `update_policy` + use_mbridge: False + profile: # profile the actor model in `update_policy` use_profile: False # open it when you want to profile the actor model profile_ranks: null # list, you can specify the ranks to profile - step_start: -1 # start step in update_policy - step_end: -1 # end step + step_start: -1 # start step in update_policy + step_end: -1 # end step save_path: null # the path to save the profile result load_weight: True checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + async_save: False # save checkpoint asynchronously + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} ref: - strategy: megatron + strategy: ${actor_rollout_ref.actor.strategy} use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} megatron: param_offload: False @@ -112,6 +139,7 @@ actor_rollout_ref: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} + use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} profile: use_profile: False profile_ranks: null @@ -129,8 +157,7 @@ actor_rollout_ref: temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 - use_fire_sampling: False # https://arxiv.org/abs/2410.21236 - prompt_length: ${data.max_prompt_length} # not use for opensource + prompt_length: ${data.max_prompt_length} # for xperf_gpt response_length: ${data.max_response_length} # for vllm rollout dtype: bfloat16 # should align with FSDP @@ -139,7 +166,7 @@ actor_rollout_ref: enforce_eager: True free_cache_engine: True load_format: dummy_megatron - tensor_model_parallel_size: 2 + tensor_model_parallel_size: 1 max_num_batched_tokens: 8192 max_model_len: null max_num_seqs: 1024 @@ -148,14 +175,20 @@ actor_rollout_ref: log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} disable_log_stats: True - enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: False # could get higher throughput # for hf rollout do_sample: True + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up # number of responses (i.e. num sample times) - n: 1 # > 1 for grpo + n: 1 engine_kwargs: # inference engine parameters vllm: swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + disable_mm_preprocessor_cache: False # whether to disable the preprocessor cache for multimodel models. + sglang: + attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla val_kwargs: # sampling parameters for validation top_k: -1 # 0 for hf rollout, -1 for vllm rollout @@ -163,31 +196,104 @@ actor_rollout_ref: temperature: 0 n: 1 do_sample: False # default eager for validation - multi_turn: - enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well - max_turns: null # null for no limit (default max_length // 3) - tool_config_path: null # null for no tool - format: chatml # chatml, more formats will be supported in the future -# NOTE: DAPO does not have a critic model. This is just a placeholder. + # Multi-turn interaction config for tools or chat. + multi_turn: + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # null for default callback + completion_callback: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + + # [Experimental] agent loop based rollout configs + agent: + + # Number of agent loop workers + num_workers: 8 + + custom_async_server: + path: null + name: null + + # support logging rollout prob for debugging purpose + calculate_log_probs: False + # Nsight system profiler configs + profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + discrete: False + all_ranks: False + ranks: [] + critic: rollout_n: ${actor_rollout_ref.rollout.n} - strategy: megatron + strategy: ${actor_rollout_ref.actor.strategy} + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron optim: - lr: 1e-5 + optimizer: adam + lr: 1e-6 clip_grad: 1.0 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 + lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + lr_decay_steps: null + lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root + min_lr: 0.0 # minimum learning rate, default to 0.0 weight_decay: 0.01 + weight_decay_incr_style: constant # select from constant/linear/cosine + lr_wsd_decay_style: exponential # select from constant/exponential/cosine + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler model: path: ~/models/deepseek-llm-7b-chat tokenizer_path: ${actor_rollout_ref.model.path} - override_config: { } + override_config: + model_config: {} + moe_config: + freeze_moe_router: False external_lib: ${actor_rollout_ref.model.external_lib} - enable_gradient_checkpointing: True - use_remove_padding: False + trust_remote_code: False + enable_gradient_checkpointing: False gradient_checkpointing_kwargs: ## Activation Checkpointing activations_checkpoint_method: null @@ -209,29 +315,39 @@ critic: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} + use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} + load_weight: True ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: null - forward_micro_batch_size: ${critic.ppo_micro_batch_size} - forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} shuffle: ${actor_rollout_ref.actor.shuffle} - grad_clip: 1.0 cliprange_value: 0.5 kl_ctrl: type: fixed kl_coef: 0.001 loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - + async_save: False # save checkpoint asynchronously + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + load_contents: ${critic.checkpoint.save_contents} + # Nsight system profiler configs + profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + discrete: False + all_ranks: False + ranks: [] reward_model: enable: False - strategy: megatron + strategy: ${actor_rollout_ref.actor.strategy} + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron megatron: param_offload: False tensor_model_parallel_size: 1 @@ -246,6 +362,7 @@ reward_model: dist_checkpointing_path: null seed: ${actor_rollout_ref.actor.megatron.seed} override_transformer_config: {} + use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} model: input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical path: ~/models/FsfairX-LLaMA3-RM-v0.1 @@ -253,61 +370,103 @@ reward_model: external_lib: ${actor_rollout_ref.model.external_lib} load_weight: True micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null # set a number - max_length: null + micro_batch_size_per_gpu: null use_dynamic_bsz: ${critic.use_dynamic_bsz} forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + max_length: null reward_manager: naive launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: True + sandbox_fusion: + url: null # faas url to run code in cloud sandbox + max_concurrent: 64 # max concurrent requests to sandbox + memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB + # Nsight system profiler configs + profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + discrete: False + all_ranks: False + ranks: [] custom_reward_function: path: null name: compute_score algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig gamma: 1.0 lam: 1.0 adv_estimator: gae + norm_adv_by_std_in_grpo: True use_kl_in_reward: False kl_penalty: kl # how to estimate kl divergence kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig type: fixed kl_coef: 0.001 horizon: 10000 target_kl: 0.1 - filter_groups: - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit + use_pf_ppo: False + pf_ppo: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow # ["pow", "max_min", "max_random"] + weight_pow: 2.0 trainer: balance_batch: True total_epochs: 30 total_training_steps: null + profile_steps: null # [1,2,5] or [] or null project_name: verl_examples experiment_name: gsm8k - logger: [ 'console', 'wandb' ] + logger: ['console', 'wandb'] log_val_generations: 0 nnodes: 1 n_gpus_per_node: 8 save_freq: -1 + esi_redundant_time: 0 + # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or disable or resume_path if resume_from_path is set resume_from_path: null + del_local_ckpt_after_load: False val_before_train: True test_freq: -1 critic_warmup: 0 default_hdfs_dir: null - remove_previous_ckpt_in_save: False - del_local_ckpt_after_load: False default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} max_actor_ckpt_to_keep: null max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + device: cuda + # see ppo_trainer.yaml for more details + controller_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + worker_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + capture-range: "cudaProfilerApi" + capture-range-end: null + kill: none + npu_profile: + options: + save_path: ./profiler_data + level: level1 + with_memory: False + record_shapes: False + with_npu: True + with_cpu: True + with_module: False + with_stack: False + analysis: True ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/tests/trainer/config/legacy_ppo_trainer.yaml b/tests/trainer/config/legacy_ppo_trainer.yaml new file mode 100644 index 000000000..8ba94e204 --- /dev/null +++ b/tests/trainer/config/legacy_ppo_trainer.yaml @@ -0,0 +1,1111 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# dataset config +data: + + # Tokenizer class or path. If null, it will be inferred from the model. + tokenizer: null + + # Whether to use shared memory for data loading. + use_shm: False + + # Training set parquet. Can be a list or a single file. + # The program will read all files into memory, so it can't be too large (< 100GB). + # The path can be either a local path or an HDFS path. + # For HDFS path, we provide utils to download it to DRAM and convert it to a local path. + train_files: ~/data/rlhf/gsm8k/train.parquet + + # Validation parquet. Can be a list or a single file. + val_files: ~/data/rlhf/gsm8k/test.parquet + + # The field in the dataset where the prompt is located. Default is 'prompt'. + prompt_key: prompt + + # The field used to select the reward function (if using different ones per example). + reward_fn_key: data_source + + # Maximum prompt length. All prompts will be left-padded to this length. + # An error will be reported if the length is too long. + max_prompt_length: 512 + + # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. + max_response_length: 512 + + # Batch size sampled for one training iteration of different RL algorithms. + train_batch_size: 1024 + + # Batch size used during validation. Can be null. + val_batch_size: null + + # Whether to return the original input_ids without adding chat template. + # This is used when the reward model's chat template differs from the policy. + # If using a model-based RM with different templates, this should be True. + return_raw_input_ids: False + + # Whether to return the original chat (prompt) without applying chat template. + return_raw_chat: False + + # Whether to return the full prompt with chat template. + return_full_prompt: False + + # Whether to shuffle the data in the dataloader. + shuffle: True + + # num dataloader workers + dataloader_num_workers: 8 + + # Whether to shuffle the validation set. + validation_shuffle: False + + # Whether to filter overlong prompts. + filter_overlong_prompts: False + + # Number of workers for filtering overlong prompts. + # For large-scale datasets, filtering can be time-consuming. + # Use multiprocessing to speed up. Default is 1. + filter_overlong_prompts_workers: 1 + + # Truncate the input_ids or prompt if they exceed max_prompt_length. + # Options: 'error', 'left', or 'right'. Default is 'error'. + truncation: error + + # The field in the multi-modal dataset where the image is located. Default is 'images'. + image_key: images + + # The field in the multi-modal dataset where the video is located. + video_key: videos + + # If the remote tokenizer has a Python file, this flag determines whether to allow using it. + trust_remote_code: False + + # Optional: specify a custom dataset class path and name if overriding default loading behavior. + custom_cls: + + # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. + path: null + + # The name of the dataset class within the specified file. + name: null + + # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. + return_multi_modal_inputs: True + + # Data generation configuration for augmenting the dataset. + datagen: + + # The path to the file containing your customized data generation class. + # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' + path: null + + # The class name of the data generation class within the specified file. + # E.g. 'MockDataGenerator' + name: null + + # settings related to data sampler + sampler: + + # the path to the module containing a curriculum class which implements the + # AbstractSampler interface + class_path: null + + # the name of the curriculum class like `MySampler` + class_name: null + +# config for actor, rollout and reference model +actor_rollout_ref: + + # Whether it's a hybrid engine, currently only supports hybrid engine + hybrid_engine: true + + # common configs for the model + model: + + # Huggingface model path. This can be either local path or HDFS path. + path: ~/models/deepseek-llm-7b-chat + + # Custom chat template for the model. + custom_chat_template: null + + # Whether to use shared memory (SHM) for accelerating the loading of model weights + use_shm: false + + # Additional Python packages to register huggingface models/tokenizers. + external_lib: null + + # Used to override model's original configurations, mainly dropout + override_config: {} + + # Enable gradient checkpointing for actor + enable_gradient_checkpointing: true + + # Enable activation offloading for actor + enable_activation_offload: false + + # Whether to remove padding tokens in inputs during training + use_remove_padding: false + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or + # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] + target_modules: all-linear + + # Exclude modules from applying Lora. Similar usage to target_modules and Peft. + # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora. + exclude_modules: null + + # Whether to use Liger for linear layer fusion + use_liger: false + + # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) + use_fused_kernels: false + + # Options for fused kernels. If use_fused_kernels is true, this will be used. + fused_kernel_options: + + # Implementation backend for fused kernels. Options: "triton" or "torch". + impl_backend: torch + + # Whether to enable loading a remote code model + trust_remote_code: false + + # actor configs + actor: + + # fsdp, fsdp2 or megatron. fsdp backend used here. + strategy: fsdp + + # Split each sample into sub-batches of this size for PPO + ppo_mini_batch_size: 256 + + # [Deprecated] Global micro batch size + ppo_micro_batch_size: null + + # Local per-GPU micro batch size + ppo_micro_batch_size_per_gpu: null + + # Whether to automatically adjust batch size at runtime + use_dynamic_bsz: false + + # Max tokens per GPU in one PPO batch; affects gradient accumulation + # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} + ppo_max_token_len_per_gpu: 16384 + + # Gradient clipping for actor updates + grad_clip: 1.0 + + # PPO clip ratio + clip_ratio: 0.2 + + # Lower bound for asymmetric clipping (used in dual-clip PPO) + clip_ratio_low: 0.2 + + # Upper bound for asymmetric clipping (used in dual-clip PPO) + clip_ratio_high: 0.2 + + # policy loss config + policy_loss: + + # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + loss_mode: "vanilla" + + # Ratio of tokens to be clipped for clip-cov loss + clip_cov_ratio: 0.0002 + + # Lower bound for clip-cov loss + clip_cov_lb: 1.0 + + # Upper bound for clip-cov loss + clip_cov_ub: 5.0 + + # Ratio of tokens to be applied kl penalty for kl-cov loss + kl_cov_ratio: 0.0002 + + # KL divergence penalty coefficient + ppo_kl_coef: 0.1 + + # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C + clip_ratio_c: 3.0 + + # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" + loss_agg_mode: token-mean + + # Entropy regularization coefficient in PPO loss + entropy_coeff: 0 + + # Whether to use KL loss instead of KL reward penalty. True for GRPO + use_kl_loss: false + + # Whether to use torch.compile() + use_torch_compile: true + + # KL loss coefficient when use_kl_loss is enabled. For GRPO + kl_loss_coef: 0.001 + + # Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" + kl_loss_type: low_var_kl + + # Number of PPO epochs per batch + ppo_epochs: 1 + + # Shuffle training data across PPO epochs + shuffle: false + + # Sequence parallelism size for Ulysses-style model parallelism + ulysses_sequence_parallel_size: 1 + + # calculate entropy with chunking to reduce memory peak + entropy_from_logits_with_chunking: False + + # recompute entropy + entropy_checkpointing: False + + # checkpoint configs + checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} + + # optimizer configs + optim: + + # Learning rate + lr: 1e-6 + + # Warmup steps; negative value delegates to lr_warmup_steps_ratio + lr_warmup_steps: -1 + + # Warmup steps ratio (used if lr_warmup_steps is negative) + lr_warmup_steps_ratio: 0.0 + + # Minimum LR ratio for cosine schedule + min_lr_ratio: 0.0 + + # Number of cosine cycles in LR schedule + num_cycles: 0.5 + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # configs for FSDP + fsdp_config: + + # policy for wrapping the model + wrap_policy: + + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + + # Whether to offload model parameters to CPU (trades speed for memory) + param_offload: false + + # Whether to offload optimizer state to CPU + optimizer_offload: false + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: false + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: true + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + ref: + + # actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default + strategy: ${actor_rollout_ref.actor.strategy} + + # config for FSDP strategy + fsdp_config: + + # whether to offload parameters in FSDP + param_offload: False + + # whether to perform reshard after model forward to save memory. + # only for fsdp2, [True, False, int between 1 and fsdp_size] + reshard_after_forward: True + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # the wrap policy for FSDP model + wrap_policy: + + # minimum number of params in a wrapped module + min_num_params: 0 + + # whether to enable torch.compile + use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} + + # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] + # The batch size for one forward pass in the computation of log_prob. Global batch size. + log_prob_micro_batch_size: null + + # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. + log_prob_micro_batch_size_per_gpu: null + + # enable dynamic batch size (sequence packing) for log_prob computation + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + + # the max token length per GPU + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + + # sequence parallel size + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} + + # calculate entropy with chunking to reduce memory peak + entropy_from_logits_with_chunking: False + + # recompute entropy + entropy_checkpointing: False + + # Rollout model config. + rollout: + + # actor_rollout_ref.rollout.name: hf/vllm/sglang. + name: vllm + + # sync: LLM, async: AsyncLLM + mode: sync + + # Sampling temperature for rollout. + temperature: 1.0 + + # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. + top_k: -1 + + # Top-p sampling parameter. Default 1.0. + top_p: 1 + + + # typically the same as data max prompt length + prompt_length: ${data.max_prompt_length} + + # typically the same as data max response length + response_length: ${data.max_response_length} + + # for vllm rollout + # Rollout model parameters type. Align with actor model's FSDP/Megatron type. + dtype: bfloat16 + + # Fraction of GPU memory used by vLLM/SGLang for KV cache. + gpu_memory_utilization: 0.5 + + # Whether to ignore EOS and continue generating after EOS is hit. + ignore_eos: False + + # Whether to disable CUDA graph. Default True to allow cache freeing. + enforce_eager: True + + # Whether to free engine KVCache after generation. Set enforce_eager=True when enabled. + free_cache_engine: True + + # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc. + # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight + load_format: dummy_dtensor + + # for huge model, layered summon can save memory (prevent OOM) but make it slower + layered_summon: False + + # TP size for rollout. Only effective for vLLM. + tensor_model_parallel_size: 2 + + # max number of tokens in a batch + max_num_batched_tokens: 8192 + + # max length for rollout + max_model_len: null + + # max length of sequences + max_num_seqs: 1024 + + # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. + log_prob_micro_batch_size: null + + # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. + log_prob_micro_batch_size_per_gpu: null + + # enable dynamic batch size (sequence packing) for log_prob computation + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + + # max token length for log_prob computation + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + + # disable logging statistics + disable_log_stats: True + + # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: True + + # for hf rollout + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: True + + # number of responses (i.e. num sample times). > 1 for grpo + n: 1 + + # Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache) + multi_stage_wake_up: false + + # Extra inference engine arguments (vllm, sglang). + engine_kwargs: + + # for vllm + vllm: + + # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB). + swap_space: null + + # Whether to disable the preprocessor cache for multimodel models. + disable_mm_preprocessor_cache: False + + # for sglang + sglang: + + # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default. + attention_backend: null + + # Sampling parameters used during validation. + val_kwargs: + + # sampling parameters for validation + # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. + top_k: -1 + + # Top-p sampling parameter. Default 1.0. + top_p: 1.0 + + # Sampling temperature for rollout. + temperature: 0 + + # whether to repeat n times for validation + n: 1 + + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: False + + # Multi-turn interaction config for tools or chat. + multi_turn: + + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # null for default callback + completion_callback: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + + # support logging rollout prob for debugging purpose + calculate_log_probs: False + + # [Experimental] agent loop based rollout configs + agent: + + # Number of agent loop workers + num_workers: 8 + + # custom async server configs + custom_async_server: + + # Path to the custom async server implementation + path: null + + # Class name of the custom async server class (e.g. AsyncvLLMServer) + name: null + + # profiler configs + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# configs for the critic +critic: + + # Number of rollouts per update (mirrors actor rollout_n) + rollout_n: ${actor_rollout_ref.rollout.n} + + # fsdp or fsdp2 strategy used for critic model training + strategy: ${actor_rollout_ref.actor.strategy} + + # optimizer configs + optim: + + # Learning rate + lr: 1e-5 + + # Warmup steps ratio; total steps will be injected at runtime + lr_warmup_steps_ratio: 0. + + # Minimum LR ratio for cosine schedule + min_lr_ratio: null + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # model config for the critic + model: + + # Path to pretrained model weights + path: ~/models/deepseek-llm-7b-chat + + # Whether to use shared memory for loading the model + use_shm: False + + # Tokenizer path (defaults to actor's model path) + tokenizer_path: ${actor_rollout_ref.model.path} + + # Hugging Face config override + override_config: { } + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: True + + # Offload activations to CPU to reduce GPU memory usage + enable_activation_offload: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Whether to trust remote code from Hugging Face models + trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} + + # FSDP-specific config + fsdp_config: + + # Whether to offload model parameters to CPU + param_offload: False + + # Whether to offload optimizer state to CPU + optimizer_offload: False + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Policy for wrapping layers with FSDP + wrap_policy: + + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # LoRA target modules: "all-linear" or list of linear projection layers + target_modules: all-linear + + # PPO mini-batch size per update + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + + # [Deprecated] Global micro batch size + ppo_micro_batch_size: null + + # Local per-GPU micro batch size + ppo_micro_batch_size_per_gpu: null + + # Forward-only batch size (global) + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + + # Forward-only batch size (per GPU) + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + + # Whether to automatically adjust batch size at runtime + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + + # Max tokens per GPU in one PPO batch (doubled for critic) + ppo_max_token_len_per_gpu: 32768 + + # Max token length per GPU in forward pass + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + + # Sequence parallelism size for Ulysses-style model parallelism + ulysses_sequence_parallel_size: 1 + + # Number of PPO epochs per batch + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + + # Shuffle training data across PPO epochs + shuffle: ${actor_rollout_ref.actor.shuffle} + + # Gradient clipping for critic updates + grad_clip: 1.0 + + # PPO value function clipping range + cliprange_value: 0.5 + + # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} + + # checkpoint configs + checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # What to include when loading checkpoints + load_contents: ${critic.checkpoint.save_contents} + + # profiler configs + # the corresponding dataclass is verl.utils.profiler.ProfilerConfig. + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# configs for the reward model +reward_model: + + # Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. + # In GSM8K and Math examples, we disable reward model. + # For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses. + # If False, the following parameters are not effective + enable: False + + # FSDP strategy: "fsdp" or "fsdp2" + strategy: ${actor_rollout_ref.actor.strategy} + + # model config for reward scoring + model: + + # Input tokenizer. If the reward model’s chat template is inconsistent with the policy, + # we need to first decode to plaintext, then apply the rm’s chat_template. + # Then score with RM. If chat_templates are consistent, it can be set to null. + input_tokenizer: ${actor_rollout_ref.model.path} + + # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. + # Other model types need to define their own RewardModelWorker and pass it from the code. + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + + # Whether to use shared memory for loading the model + use_shm: False + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Whether to use fused reward kernels for speedup + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + + # Whether to enable loading a remote code model, default to False + trust_remote_code: False + + # FSDP-specific config + fsdp_config: + + # Policy for wrapping layers with FSDP + wrap_policy: + + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Whether to offload model parameters to CPU + param_offload: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # [Deprecated] Global micro batch size + micro_batch_size: null + + # Local per-GPU micro batch size + micro_batch_size_per_gpu: null + + # Maximum sequence length to process for scoring + max_length: null + + # Sequence parallelism size for Ulysses-style model parallelism + ulysses_sequence_parallel_size: 1 + + # Whether to dynamically adjust batch size at runtime + use_dynamic_bsz: ${critic.use_dynamic_bsz} + + # Maximum number of tokens per GPU in one forward pass + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + + # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources. + # Default is naive. If all verification functions are multiprocessing-safe, + # the reward manager can be set to prime for parallel verification. + reward_manager: naive + + # Whether to launch custom reward function asynchronously during log_prob + launch_reward_fn_async: False + + # Cloud/local sandbox fusion configuration for custom reward logic + sandbox_fusion: + + # Cloud/local function URL for sandbox execution + url: null + + # Max concurrent requests allowed to sandbox + max_concurrent: 64 + + # Max memory limit for each sandbox process in MB + memory_limit_mb: 1024 + + # profiler configs + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# custom reward function definition +custom_reward_function: + + # The path to the file containing your customized reward function. + # If not specified, pre-implemented reward functions will be used. + path: null + + # The name of the reward function within the specified file. Default is 'compute_score'. + name: compute_score + +# config for the algorithm +algorithm: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig + + # Discount factor for future rewards + gamma: 1.0 + + # Trade-off between bias and variance in the GAE estimator + lam: 1.0 + + # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + adv_estimator: gae + + # Whether to normalize advantages by std (specific to GRPO) + norm_adv_by_std_in_grpo: True + + # Whether to enable in-reward KL penalty + use_kl_in_reward: False + + # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_penalty: kl + + # KL control configuration + kl_ctrl: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig + + # KL control type: "fixed" or "adaptive" + type: fixed + + # Initial coefficient for KL penalty + kl_coef: 0.001 + + # Horizon value for adaptive controller (if enabled) + horizon: 10000 + + # Target KL divergence (used for adaptive controller) + target_kl: 0.1 + + # Whether to enable preference feedback PPO + use_pf_ppo: False + + # Preference feedback PPO settings + pf_ppo: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + + # Method for reweighting samples: "pow", "max_min", or "max_random" + reweight_method: pow + + # Power used for weight scaling in "pow" method + weight_pow: 2.0 + +# config for the trainer +trainer: + + # Whether to balance batch sizes across distributed workers + balance_batch: True + + # Number of epochs in training + total_epochs: 30 + + # Total training steps (can be set explicitly or derived from epochs) + total_training_steps: null + + # The steps that will be profiled. null means no profiling. null or [1,2,5,...] + profile_steps: null + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # Config for npu profiler. Must set when profile_steps is not None and torch_npu is available. + npu_profile: + + # Options for the npu profiler + options: + + # Storage path of collected data. + save_path: ./profiler_data + + # Collection level, optional values: level_none, level0, level1, level2. + level: level1 + + # Whether to enable memory analysis. + with_memory: False + + # Whether to record tensor shape. + record_shapes: False + + # Whether to record Device-side performance data. + with_npu: True + + # Whether to record Host-side performance data. + with_cpu: True + + # Whether to record Python call stack information. + with_module: False + + # Whether to record operator call stack information. + with_stack: False + + # Whether to automatically parse the data. + analysis: True + + # Project name for experiment tracking (e.g., wandb) + project_name: verl_examples + + # Experiment name for run identification in tracking tools + experiment_name: gsm8k + + # Logging backends to use: "console", "wandb", etc. + logger: [ 'console', 'wandb' ] + + # Number of generations to log during validation + log_val_generations: 0 + + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # Directory for logging validation data; no dump if null + validation_data_dir: null + + # Number of nodes used in the training + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # Save frequency (by iteration) for model checkpoints + save_freq: -1 + + # ESI refers to the elastic server instance used during training, similar to the training plan. For example, + # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. + # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. + # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. + # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. + esi_redundant_time: 0 + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (only used when resume_mode is "resume_path") + resume_from_path: null + + # Whether to run validation before training begins + val_before_train: True + + # Whether to run validation only + val_only: False + + # Validation frequency (in training iterations) + test_freq: -1 + + # Number of iterations to warm up the critic before updating policy + critic_warmup: 0 + + # Default path to distributed filesystem for saving checkpoints + default_hdfs_dir: null + + # Whether to delete local checkpoints after loading + del_local_ckpt_after_load: False + + # Default local directory for saving checkpoints + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + + # Maximum number of actor checkpoints to keep + max_actor_ckpt_to_keep: null + + # Maximum number of critic checkpoints to keep + max_critic_ckpt_to_keep: null + + # Timeout (in seconds) for Ray worker to wait for registration + ray_wait_register_center_timeout: 300 + + # Device to run training on (e.g., "cuda", "cpu") + device: cuda + +# configs related to ray initialization +ray_init: + + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null + + # Path to save Ray timeline JSON for performance profiling + timeline_json_file: null diff --git a/tests/trainer/config/test_algo_config_on_cpu.py b/tests/trainer/config/test_algo_config_on_cpu.py new file mode 100644 index 000000000..848a3ffe1 --- /dev/null +++ b/tests/trainer/config/test_algo_config_on_cpu.py @@ -0,0 +1,206 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from omegaconf import OmegaConf + +from verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig +from verl.trainer.ppo.core_algos import ( + compute_gae_advantage_return, + compute_grpo_outcome_advantage, + get_adv_estimator_fn, +) +from verl.utils.config import omega_conf_to_dataclass + + +class TestAlgoConfig(unittest.TestCase): + """Test the AlgoConfig dataclass and its integration with core algorithms.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a sample algorithm config as DictConfig (similar to what comes from YAML) + self.config_dict = { + "_target_": "verl.trainer.config.AlgoConfig", + "gamma": 0.99, + "lam": 0.95, + "adv_estimator": "gae", + "norm_adv_by_std_in_grpo": True, + "use_kl_in_reward": True, + "kl_penalty": "kl", + "kl_ctrl": { + "_target_": "verl.trainer.config.KLControlConfig", + "type": "adaptive", + "kl_coef": 0.002, + "horizon": 5000, + "target_kl": 0.05, + }, + "use_pf_ppo": True, + "pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0}, + } + self.omega_config = OmegaConf.create(self.config_dict) + + def test_dataclass_creation_from_dict(self): + """Test creating AlgoConfig from dictionary.""" + config = omega_conf_to_dataclass(self.config_dict) + + self.assertIsInstance(config, AlgoConfig) + self.assertEqual(config.gamma, 0.99) + self.assertEqual(config.lam, 0.95) + self.assertEqual(config.adv_estimator, "gae") + self.assertTrue(config.norm_adv_by_std_in_grpo) + self.assertTrue(config.use_kl_in_reward) + self.assertEqual(config.kl_penalty, "kl") + self.assertTrue(config.use_pf_ppo) + + def test_dataclass_creation_from_omega_config(self): + """Test creating AlgoConfig from OmegaConf DictConfig.""" + config = omega_conf_to_dataclass(self.omega_config) + + self.assertIsInstance(config, AlgoConfig) + self.assertEqual(config.gamma, 0.99) + self.assertEqual(config.lam, 0.95) + + def test_nested_configs(self): + """Test that nested configurations are properly converted.""" + config = omega_conf_to_dataclass(self.omega_config) + + # Test KL control config + self.assertIsInstance(config.kl_ctrl, KLControlConfig) + self.assertEqual(config.kl_ctrl.type, "adaptive") + self.assertEqual(config.kl_ctrl.kl_coef, 0.002) + self.assertEqual(config.kl_ctrl.horizon, 5000) + self.assertEqual(config.kl_ctrl.target_kl, 0.05) + + # Test PF PPO config + self.assertIsInstance(config.pf_ppo, PFPPOConfig) + self.assertEqual(config.pf_ppo.reweight_method, "max_min") + self.assertEqual(config.pf_ppo.weight_pow, 3.0) + + def test_default_values(self): + """Test that default values are properly set.""" + minimal_config = {"gamma": 0.8} + config = omega_conf_to_dataclass(minimal_config, AlgoConfig) + + self.assertEqual(config.gamma, 0.8) + self.assertEqual(config.lam, 1.0) # default value + self.assertEqual(config.adv_estimator, "gae") # default value + self.assertTrue(config.norm_adv_by_std_in_grpo) # default value + self.assertFalse(config.use_kl_in_reward) # default value + self.assertEqual(config.kl_penalty, "kl") # default value + self.assertFalse(config.use_pf_ppo) # default value + + def test_get_method_backward_compatibility(self): + """Test the get method for backward compatibility.""" + config = omega_conf_to_dataclass(self.omega_config) + + # Test existing attribute + self.assertEqual(config.get("gamma"), 0.99) + self.assertEqual(config.get("gamma", 1.0), 0.99) + + # Test non-existing attribute + self.assertIsNone(config.get("non_existing")) + self.assertEqual(config.get("non_existing", "default"), "default") + + def test_post_init_nested_configs(self): + """Test that __post_init__ properly initializes nested configs when None.""" + # Create config without nested configs + minimal_config = AlgoConfig(gamma=0.9) + + # Check that nested configs are initialized + self.assertIsNotNone(minimal_config.kl_ctrl) + self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig) + self.assertIsNone(minimal_config.pf_ppo) + + def test_config_init_from_yaml(self): + import os + + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + cfg = compose(config_name="ppo_trainer") + algo_config = omega_conf_to_dataclass(cfg.algorithm) + from verl.trainer.config import AlgoConfig, PFPPOConfig + + assert isinstance(algo_config, AlgoConfig) + assert isinstance(algo_config.pf_ppo, PFPPOConfig) + + +class TestAlgoCompute(unittest.TestCase): + """Test the AlgoConfig dataclass and its integration with core algorithms.""" + + def setUp(self): + """Set up test fixtures.""" + self.algo_config = AlgoConfig( + gamma=0.99, + lam=0.95, + adv_estimator="gae", + norm_adv_by_std_in_grpo=True, + use_kl_in_reward=True, + kl_penalty="kl", + kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05), + use_pf_ppo=True, + pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0), + ) + + def test_advantage_estimator_with_cfg(self): + """Test integration with advantage estimators from core_algos.""" + config = self.algo_config + + # Test GAE advantage estimator + adv_fn = get_adv_estimator_fn(config.adv_estimator) + self.assertIsNotNone(adv_fn) + + # Test with actual GAE computation + batch_size, seq_len = 2, 5 + token_level_rewards = torch.randn(batch_size, seq_len) + values = torch.randn(batch_size, seq_len) + response_mask = torch.ones(batch_size, seq_len) + + advantages, returns = compute_gae_advantage_return( + token_level_rewards=token_level_rewards, + values=values, + response_mask=response_mask, + gamma=config.gamma, + lam=config.lam, + ) + + self.assertEqual(advantages.shape, (batch_size, seq_len)) + self.assertEqual(returns.shape, (batch_size, seq_len)) + + def test_grpo_advantage_estimator_with_cfg(self): + """Test integration with GRPO advantage estimator.""" + grpo_config = AlgoConfig(adv_estimator="grpo", norm_adv_by_std_in_grpo=True) + + # Test GRPO advantage computation + batch_size, seq_len = 4, 3 + token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]]) + response_mask = torch.ones(batch_size, seq_len) + index = np.array([0, 0, 1, 1]) # Two groups + + advantages, returns = compute_grpo_outcome_advantage( + token_level_rewards=token_level_rewards, + response_mask=response_mask, + index=index, + norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo, + ) + + self.assertEqual(advantages.shape, (batch_size, seq_len)) + self.assertEqual(returns.shape, (batch_size, seq_len)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/trainer/config/test_legacy_config_on_cpu.py b/tests/trainer/config/test_legacy_config_on_cpu.py new file mode 100644 index 000000000..39862aa22 --- /dev/null +++ b/tests/trainer/config/test_legacy_config_on_cpu.py @@ -0,0 +1,133 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from hydra import compose, initialize_config_dir +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf + + +class TestConfigComparison(unittest.TestCase): + """Test that current configs match their legacy counterparts exactly.""" + + def _compare_configs_recursively(self, current_config, legacy_config, path="", legacy_allow_missing=True): + """Recursively compare two OmegaConf configs and assert they are identical. + + Args: + legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and + we allow that to happen + """ + if isinstance(current_config, dict) and isinstance(legacy_config, dict): + current_keys = set(current_config.keys()) + legacy_keys = set(legacy_config.keys()) + + missing_in_current = legacy_keys - current_keys + missing_in_legacy = current_keys - legacy_keys + + if missing_in_current: + self.fail(f"Keys missing in current config at {path}: {missing_in_current}") + if missing_in_legacy: + # if the legacy + msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}" + if legacy_allow_missing: + print(msg) + else: + self.fail(msg) + + for key in current_keys: + current_path = f"{path}.{key}" if path else key + if key in legacy_config: + self._compare_configs_recursively(current_config[key], legacy_config[key], current_path) + elif isinstance(current_config, list) and isinstance(legacy_config, list): + self.assertEqual( + len(current_config), + len(legacy_config), + f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}", + ) + for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)): + self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]") + else: + self.assertEqual( + current_config, + legacy_config, + f"Values differ at {path}: current={current_config}, legacy={legacy_config}", + ) + + def test_ppo_trainer_config_matches_legacy(self): + """Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly.""" + import os + + from hydra import compose, initialize_config_dir + from hydra.core.global_hydra import GlobalHydra + + GlobalHydra.instance().clear() + + try: + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + current_config = compose(config_name="ppo_trainer") + + legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml") + current_dict = OmegaConf.to_container(current_config, resolve=True) + legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) + + if "defaults" in current_dict: + del current_dict["defaults"] + + self._compare_configs_recursively(current_dict, legacy_dict) + finally: + GlobalHydra.instance().clear() + + def test_ppo_megatron_trainer_config_matches_legacy(self): + """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.""" + + GlobalHydra.instance().clear() + + try: + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + current_config = compose(config_name="ppo_megatron_trainer") + + legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml") + current_dict = OmegaConf.to_container(current_config, resolve=True) + legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) + + if "defaults" in current_dict: + del current_dict["defaults"] + + self._compare_configs_recursively(current_dict, legacy_dict, legacy_allow_missing=True) + finally: + GlobalHydra.instance().clear() + + def test_load_component(self): + """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.""" + + GlobalHydra.instance().clear() + configs_to_load = [ + ("verl/trainer/config/actor", "dp_actor"), + ("verl/trainer/config/actor", "megatron_actor"), + ("verl/trainer/config/ref", "dp_ref"), + ("verl/trainer/config/ref", "megatron_ref"), + ("verl/trainer/config/rollout", "rollout"), + ] + for config_dir, config_file in configs_to_load: + try: + with initialize_config_dir(config_dir=os.path.abspath(config_dir)): + compose(config_name=config_file) + finally: + GlobalHydra.instance().clear() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/trainer/ppo/test_core_algos_on_cpu.py b/tests/trainer/ppo/test_core_algos_on_cpu.py new file mode 100644 index 000000000..087a0d2f1 --- /dev/null +++ b/tests/trainer/ppo/test_core_algos_on_cpu.py @@ -0,0 +1,192 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import pytest +import torch + +import verl.trainer.ppo.core_algos +from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est + + +def mock_test_fn(): + pass + + +class TestRegisterAdvEst(unittest.TestCase): + def setUp(self): + """Clear the registry before each test""" + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = { + "gae": lambda x: x * 2, + "vtrace": lambda x: x + 1, + } + self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY + + def tearDown(self) -> None: + verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() + return super().tearDown() + + def test_register_new_function(self): + """Test registering a new function with a string name""" + + @register_adv_est("test_estimator") + def test_fn(): + pass + + self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn) + + def test_register_with_enum(self): + """Test registering with an enum value (assuming AdvantageEstimator exists)""" + from enum import Enum + + class AdvantageEstimator(Enum): + TEST = "test_enum_estimator" + + @register_adv_est(AdvantageEstimator.TEST) + def test_fn(): + pass + + self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn) + + def test_duplicate_registration_same_function(self): + """Test that registering the same function twice doesn't raise an error""" + register_adv_est("duplicate_test")(mock_test_fn) + register_adv_est("duplicate_test")(mock_test_fn) + + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn) + + def test_duplicate_registration_different_function(self): + """Test that registering different functions with same name raises ValueError""" + + @register_adv_est("conflict_test") + def test_fn1(): + pass + + with self.assertRaises(ValueError): + + @register_adv_est("conflict_test") + def test_fn2(): + pass + + def test_decorator_preserves_function(self): + """Test that the decorator returns the original function""" + + def test_fn(): + return "original" + + decorated = register_adv_est("preserve_test")(test_fn) + self.assertEqual(decorated(), "original") + + def test_multiple_registrations(self): + """Test registering multiple different functions""" + init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY) + + @register_adv_est("estimator1") + def fn1(): + pass + + @register_adv_est("estimator2") + def fn2(): + pass + + self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1) + self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2) + + def test_get_adv_estimator_fn_valid_names(self): + """Test that valid names return the correct function from registry.""" + # Test GAE + gae_fn = get_adv_estimator_fn("gae") + assert gae_fn(5) == 10 # 5 * 2 = 10 + + # Test Vtrace + vtrace_fn = get_adv_estimator_fn("vtrace") + assert vtrace_fn(5) == 6 # 5 + 1 = 6 + + def test_get_adv_estimator_fn_invalid_name(self): + """Test that invalid names raise ValueError.""" + with pytest.raises(ValueError) as excinfo: + get_adv_estimator_fn("invalid_name") + assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value) + + def test_get_adv_estimator_fn_case_sensitive(self): + """Test that name lookup is case-sensitive.""" + with pytest.raises(ValueError): + get_adv_estimator_fn("GAE") # Different case + + +def test_multi_turn_compute_gae_advantage_return(): + """Test multi-turn GAE skip observation tokens.""" + gamma = random.uniform(0.0, 1.0) + lam = random.uniform(0.0, 1.0) + + rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float) + + values1 = torch.tensor( + [ + [ + random.uniform(-100.0, 100.0), + random.random(), + 4.0, + 5.0, + 6.0, + random.uniform(-100.0, 0), + random.random(), + 7.0, + 9.0, + 0.0, + 0.0, + ] + ], + dtype=torch.float, + ) + + values2 = torch.tensor( + [ + [ + random.random(), + random.uniform(-100.0, 100.0), + 4.0, + 5.0, + 6.0, + random.random(), + random.uniform(0.0, 100.0), + 7.0, + 9.0, + 0.0, + 0.0, + ] + ], + dtype=torch.float, + ) + + response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float) + + adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam) + adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam) + + ret1 *= response_mask + ret2 *= response_mask + assert torch.equal(adv1, adv2), f"{adv1=}, {adv2=}" + assert torch.equal(ret1, ret2), f"{ret1=}, {ret2=}" + print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/trainer/ppo/test_metric_utils.py b/tests/trainer/ppo/test_metric_utils_on_cpu.py similarity index 92% rename from tests/trainer/ppo/test_metric_utils.py rename to tests/trainer/ppo/test_metric_utils_on_cpu.py index 5e685e5a9..50fe952c0 100644 --- a/tests/trainer/ppo/test_metric_utils.py +++ b/tests/trainer/ppo/test_metric_utils_on_cpu.py @@ -44,32 +44,32 @@ def test_reduce_metrics_basic(self): "accuracy": [0.0, 0.5, 1.0], } result = reduce_metrics(metrics) - + self.assertEqual(result["loss"], 2.0) self.assertEqual(result["accuracy"], 0.5) - + def test_reduce_metrics_empty(self): """Test that reduce_metrics handles empty lists.""" metrics = { "empty": [], } result = reduce_metrics(metrics) - + self.assertTrue(np.isnan(result["empty"])) - + def test_reduce_metrics_single_value(self): """Test that reduce_metrics works with single values.""" metrics = { "single": [5.0], } result = reduce_metrics(metrics) - + self.assertEqual(result["single"], 5.0) class TestComputeDataMetrics(unittest.TestCase): """Tests for the compute_data_metrics function.""" - + def setUp(self): """Set up common test data.""" # Create a mock DataProto object @@ -80,17 +80,25 @@ def setUp(self): "advantages": torch.tensor([[0.1, 0.2], [0.3, 0.4]]), "returns": torch.tensor([[1.1, 1.2], [1.3, 1.4]]), "responses": torch.zeros((2, 2)), # 2 samples, 2 tokens each - "attention_mask": torch.tensor([ - [1, 1, 1, 1], # 2 prompt tokens, 2 response tokens - [1, 1, 1, 1], - ]), + "attention_mask": torch.tensor( + [ + [1, 1, 1, 1], # 2 prompt tokens, 2 response tokens + [1, 1, 1, 1], + ] + ), + "response_mask": torch.tensor( + [ + [1, 1], # 2 response tokens + [1, 1], + ] + ), "values": torch.tensor([[0.9, 1.0], [1.1, 1.2]]), } - + def test_compute_data_metrics_with_critic(self): """Test compute_data_metrics with critic enabled.""" metrics = compute_data_metrics(self.batch, use_critic=True) - + # Check that all expected metrics are present self.assertIn("critic/score/mean", metrics) self.assertIn("critic/rewards/mean", metrics) @@ -100,19 +108,19 @@ def test_compute_data_metrics_with_critic(self): self.assertIn("critic/vf_explained_var", metrics) self.assertIn("response_length/mean", metrics) self.assertIn("prompt_length/mean", metrics) - + # Check some specific values self.assertAlmostEqual(metrics["critic/score/mean"], 5.0) # Sum of token_level_scores self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5) # Sum of token_level_rewards - + def test_compute_data_metrics_without_critic(self): """Test compute_data_metrics with critic disabled.""" metrics = compute_data_metrics(self.batch, use_critic=False) - + # Check that critic-specific metrics are not present self.assertNotIn("critic/values/mean", metrics) self.assertNotIn("critic/vf_explained_var", metrics) - + # Check that other metrics are still present self.assertIn("critic/score/mean", metrics) self.assertIn("critic/rewards/mean", metrics) @@ -121,48 +129,50 @@ def test_compute_data_metrics_without_critic(self): class TestComputeTimingMetrics(unittest.TestCase): """Tests for the compute_timing_metrics function.""" - + def setUp(self): """Set up common test data.""" # Create a mock DataProto object self.batch = MagicMock() self.batch.batch = { "responses": torch.zeros((2, 3)), # 2 samples, 3 response tokens each - "attention_mask": torch.tensor([ - [1, 1, 1, 1, 1, 1], # 3 prompt tokens, 3 response tokens - [1, 1, 1, 1, 1, 1], - ]), + "attention_mask": torch.tensor( + [ + [1, 1, 1, 1, 1, 1], # 3 prompt tokens, 3 response tokens + [1, 1, 1, 1, 1, 1], + ] + ), } - + # Mock the _compute_response_info function to return known values self.response_info = { "prompt_length": torch.tensor([3.0, 3.0]), "response_length": torch.tensor([3.0, 3.0]), "response_mask": torch.ones((2, 3)), } - + @patch("verl.trainer.ppo.metric_utils._compute_response_info") def test_compute_timing_metrics(self, mock_compute_response_info): """Test compute_timing_metrics with various timing data.""" mock_compute_response_info.return_value = self.response_info - + timing_raw = { "gen": 0.5, # 500ms "ref": 0.3, # 300ms "values": 0.2, # 200ms } - + metrics = compute_timing_metrics(self.batch, timing_raw) - + # Check raw timing metrics self.assertEqual(metrics["timing_s/gen"], 0.5) self.assertEqual(metrics["timing_s/ref"], 0.3) self.assertEqual(metrics["timing_s/values"], 0.2) - + # Check per-token timing metrics # gen uses only response tokens (6 tokens) self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5) - + # ref and values use all tokens (12 tokens) self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5) self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5) @@ -170,7 +180,7 @@ def test_compute_timing_metrics(self, mock_compute_response_info): class TestComputeThroughputMetrics(unittest.TestCase): """Tests for the compute_throughout_metrics function.""" - + def setUp(self): """Set up common test data.""" # Create a mock DataProto object @@ -178,23 +188,23 @@ def setUp(self): self.batch.meta_info = { "global_token_num": [100, 200, 300], # 600 tokens total } - + def test_compute_throughout_metrics(self): """Test compute_throughout_metrics with various timing data.""" timing_raw = { "step": 2.0, # 2 seconds per step } - + # Test with 1 GPU metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1) - + self.assertEqual(metrics["perf/total_num_tokens"], 600) self.assertEqual(metrics["perf/time_per_step"], 2.0) self.assertEqual(metrics["perf/throughput"], 600 / 2.0) # 300 tokens/sec - + # Test with 2 GPUs metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2) - + self.assertEqual(metrics["perf/total_num_tokens"], 600) self.assertEqual(metrics["perf/time_per_step"], 2.0) self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2)) # 150 tokens/sec/GPU @@ -202,31 +212,31 @@ def test_compute_throughout_metrics(self): class TestBootstrapMetric(unittest.TestCase): """Tests for the bootstrap_metric function.""" - + def test_bootstrap_metric_basic(self): """Test bootstrap_metric with simple data and functions.""" data = [1, 2, 3, 4, 5] reduce_fns = [np.mean, np.max] - + # Use a fixed seed for reproducibility result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42) - + # Check that we get two results (one for each reduce_fn) self.assertEqual(len(result), 2) - + # Each result should be a tuple of (mean, std) mean_result, max_result = result self.assertEqual(len(mean_result), 2) self.assertEqual(len(max_result), 2) - + # The mean of means should be close to the true mean (3.0) self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3) - + # The mean of maxes should be close to the expected value for samples of size 3 # For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5 self.assertGreater(max_result[0], 3.5) self.assertLess(max_result[0], 5.0) - + def test_bootstrap_metric_empty(self): """Test bootstrap_metric with empty data.""" with self.assertRaises(ValueError): @@ -235,7 +245,7 @@ def test_bootstrap_metric_empty(self): class TestCalcMajVal(unittest.TestCase): """Tests for the calc_maj_val function.""" - + def test_calc_maj_val_basic(self): """Test calc_maj_val with simple data.""" data = [ @@ -243,12 +253,12 @@ def test_calc_maj_val_basic(self): {"pred": "B", "val": 0.8}, {"pred": "A", "val": 0.7}, ] - + result = calc_maj_val(data, vote_key="pred", val_key="val") - + # "A" is the majority vote, so we should get the first "val" for "A" self.assertEqual(result, 0.9) - + def test_calc_maj_val_tie(self): """Test calc_maj_val with tied votes.""" data = [ @@ -257,18 +267,18 @@ def test_calc_maj_val_tie(self): {"pred": "B", "val": 0.7}, {"pred": "A", "val": 0.6}, ] - + # In case of a tie, the first key in sorted order wins # This depends on Python's dict implementation, but for this test # we just verify that one of the valid values is returned result = calc_maj_val(data, vote_key="pred", val_key="val") - + self.assertTrue(result in [0.9, 0.8]) class TestProcessValidationMetrics(unittest.TestCase): """Tests for the process_validation_metrics function.""" - + def test_process_validation_metrics_basic(self): """Test process_validation_metrics with simple data.""" data_sources = ["source1", "source1", "source2"] @@ -276,24 +286,22 @@ def test_process_validation_metrics_basic(self): infos_dict = { "score": [0.8, 0.9, 0.7], } - - result = process_validation_metrics( - data_sources, sample_inputs, infos_dict, seed=42 - ) - + + result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) + # Check the structure of the result self.assertIn("source1", result) self.assertIn("source2", result) - + # Check that source1 has metrics for score self.assertIn("score", result["source1"]) - + # Check that mean@2 is present for source1/score self.assertIn("mean@2", result["source1"]["score"]) - + # Check the value of mean@2 for source1/score self.assertAlmostEqual(result["source1"]["score"]["mean@2"], 0.85) - + def test_process_validation_metrics_with_pred(self): """Test process_validation_metrics with prediction data.""" data_sources = ["source1", "source1", "source1"] @@ -302,14 +310,12 @@ def test_process_validation_metrics_with_pred(self): "score": [0.8, 0.9, 0.7], "pred": ["A", "B", "A"], } - - result = process_validation_metrics( - data_sources, sample_inputs, infos_dict, seed=42 - ) - + + result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) + # Check that majority voting metrics are present self.assertIn("maj@2/mean", result["source1"]["score"]) - + # For bootstrap with n=2, the majority vote could be either A or B # depending on the random sampling, so we don't check the exact value diff --git a/tests/utils/cpu_tests/_test_module.py b/tests/utils/_test_module.py similarity index 100% rename from tests/utils/cpu_tests/_test_module.py rename to tests/utils/_test_module.py diff --git a/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py b/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py new file mode 100644 index 000000000..203494bd9 --- /dev/null +++ b/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py @@ -0,0 +1,70 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +from datetime import datetime, timedelta +from unittest import TestCase + +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi + + +class TestShouldSaveCkptEsi(TestCase): + def test_no_expiration_timestamp(self): + """Test case when no expiration timestamp is set""" + os.environ.pop("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) + os.environ.pop("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) + self.assertFalse(should_save_ckpt_esi(100)) + + def test_mlp_expiration_valid(self): + """Test valid MLP expiration timestamp requiring save""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 90) + self.assertTrue(should_save_ckpt_esi(30)) # max_steps_duration=30 seconds + + def test_mlp_expiration_passed(self): + """Test expired MLP timestamp""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time - 10) + self.assertFalse(should_save_ckpt_esi(30)) + + def test_mlp_invalid_timestamp(self): + """Test invalid MLP timestamp format""" + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = "invalid" + self.assertFalse(should_save_ckpt_esi(30)) + + def test_mlp_expiration_not_reached(self): + """Test MLP expiration timestamp with insufficient remaining time""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 200) + self.assertFalse(should_save_ckpt_esi(30)) # max_steps_duration=30 + + def test_aws_expiration_not_reached(self): + """Test AWS expiration timestamp with sufficient remaining time""" + now = datetime.now() + expiration = now + timedelta(minutes=100) # Exceeds 90-minute threshold + os.environ["SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(int(expiration.timestamp())) + self.assertFalse(should_save_ckpt_esi(30 * 60)) + + def test_redundant_time(self): + """Test redundant_time parameter effect""" + current_time = time.time() + # Total required: 60+30+30=120 seconds + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 120) + self.assertTrue(should_save_ckpt_esi(30, redundant_time=30)) + + def test_zero_max_steps_duration(self): + """Test zero max_steps_duration""" + current_time = time.time() + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 60) + self.assertFalse(should_save_ckpt_esi(0)) diff --git a/tests/utils/dataset/test_create_rl_sampler_on_cpu.py b/tests/utils/dataset/test_create_rl_sampler_on_cpu.py new file mode 100644 index 000000000..35bf5a3ab --- /dev/null +++ b/tests/utils/dataset/test_create_rl_sampler_on_cpu.py @@ -0,0 +1,108 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +test create_rl_sampler +""" + +from collections.abc import Sized + +import pytest +import torch +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import Dataset, RandomSampler + +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.trainer.main_ppo import create_rl_sampler + + +class RandomCurriculumSampler(AbstractCurriculumSampler): + def __init__( + self, + data_source: Sized, + data_config: DictConfig, + ): + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(1) + sampler = RandomSampler(data_source=data_source) + self.sampler = sampler + + def __iter__(self): + return self.sampler.__iter__() + + def __len__(self) -> int: + return len(self.sampler) + + def update(self, batch) -> None: + return + + +class MockIncorrectSampler: + """A fake sampler class that does not adhere to the AbstractCurriculumSampler interface.""" + + def __init__(self, data_source, data_config): + pass + + +class MockChatDataset(Dataset): + def __init__(self): + self.data = [ + {"prompt": "What's your name?", "response": "My name is Assistant."}, + {"prompt": "How are you?", "response": "I'm doing well, thank you."}, + {"prompt": "What is the capital of France?", "response": "Paris."}, + { + "prompt": "Tell me a joke.", + "response": "Why did the chicken cross the road? To get to the other side!", + }, + {"prompt": "What is 2+2?", "response": "4"}, + ] + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + +def test_create_custom_curriculum_samper(): + data_config = OmegaConf.create( + { + "dataloader_num_workers": 0, + "sampler": { + "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", + "class_name": "RandomCurriculumSampler", + }, + } + ) + + dataset = MockChatDataset() + + # doesn't raise + create_rl_sampler(data_config, dataset) + + +def test_create_custom_curriculum_samper_wrong_class(): + data_config = OmegaConf.create( + { + "sampler": { + "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", + "class_name": "MockIncorrectSampler", + } + } + ) + + dataset = MockChatDataset() + + # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises + with pytest.raises(AssertionError): + create_rl_sampler(data_config, dataset) diff --git a/tests/utils/gpu_tests/dataset/test_multiturn_sft_dataset.py b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py similarity index 90% rename from tests/utils/gpu_tests/dataset/test_multiturn_sft_dataset.py rename to tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py index 64ee8f532..8028d44e5 100644 --- a/tests/utils/gpu_tests/dataset/test_multiturn_sft_dataset.py +++ b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py @@ -75,7 +75,9 @@ def test_multiturn_sft_dataset(): # Test 3: Shape Consistency assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape" - assert item0["attention_mask"].shape == item0["input_ids"].shape, "Attention mask shape doesn't match input_ids shape" + assert item0["attention_mask"].shape == item0["input_ids"].shape, ( + "Attention mask shape doesn't match input_ids shape" + ) assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape" # Test 4: Loss Mask Pattern - Math Conversation @@ -116,7 +118,9 @@ def test_multiturn_sft_dataset(): # Test 7: Position IDs Pattern position_ids0 = item0["position_ids"] - assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), "Position IDs not sequential for non-padded tokens" + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), ( + "Position IDs not sequential for non-padded tokens" + ) if sequence_length < len(position_ids0): assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" @@ -137,7 +141,9 @@ def test_multiturn_sft_dataset(): # The content should NOT appear in the non-masked text non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - assert msg["content"] not in non_assistant_text, f"Assistant message '{msg['content']}' found in non-assistant text" + assert msg["content"] not in non_assistant_text, ( + f"Assistant message '{msg['content']}' found in non-assistant text" + ) # Test 9: Verify non-assistant parts have loss_mask=0 # Get non-assistant text @@ -147,10 +153,14 @@ def test_multiturn_sft_dataset(): # Verify that system and user messages are in the non-assistant text for msg in test_data["messages"][0]: # First conversation if msg["role"] in ["system", "user"]: - assert msg["content"] in non_assistant_text, f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + assert msg["content"] in non_assistant_text, ( + f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + ) # And verify they're NOT in the assistant text - assert msg["content"] not in assistant_text, f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + assert msg["content"] not in assistant_text, ( + f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + ) # Test 10: Verify padding behavior padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}} @@ -161,7 +171,9 @@ def test_multiturn_sft_dataset(): actual_length = torch.sum(padded_item["attention_mask"]) # Verify padding tokens - assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), "Padding tokens not set correctly" + assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), ( + "Padding tokens not set correctly" + ) assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" diff --git a/tests/utils/gpu_tests/dataset/test_rl_dataset.py b/tests/utils/dataset/test_rl_dataset_on_cpu.py similarity index 95% rename from tests/utils/gpu_tests/dataset/test_rl_dataset.py rename to tests/utils/dataset/test_rl_dataset_on_cpu.py index 4f87cd510..2afc3ef49 100644 --- a/tests/utils/gpu_tests/dataset/test_rl_dataset.py +++ b/tests/utils/dataset/test_rl_dataset_on_cpu.py @@ -104,8 +104,8 @@ def test_image_rl_data(): data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - assert "multi_modal_data" in data_proto.non_tensor_batch - assert "multi_modal_inputs" in data_proto.non_tensor_batch + assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto + assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto data = dataset[0]["input_ids"] output = tokenizer.batch_decode([data])[0] diff --git a/tests/utils/gpu_tests/dataset/test_sft_dataset.py b/tests/utils/dataset/test_sft_dataset_on_cpu.py similarity index 100% rename from tests/utils/gpu_tests/dataset/test_sft_dataset.py rename to tests/utils/dataset/test_sft_dataset_on_cpu.py diff --git a/tests/utils/gpu_tests/checkpoint/run_deepseek_megatron_ckpt.sh b/tests/utils/gpu_tests/checkpoint/run_deepseek_megatron_ckpt.sh deleted file mode 100644 index 9e35b58b5..000000000 --- a/tests/utils/gpu_tests/checkpoint/run_deepseek_megatron_ckpt.sh +++ /dev/null @@ -1,91 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='deepseek_megatron_checkpoint_saveload' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=1 $@ - - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='deepseek_megatron_checkpoint_saveload' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.resume_mode=auto \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=1 $@ \ No newline at end of file diff --git a/tests/utils/gpu_tests/checkpoint/run_qwen_megatron_ckpt.sh b/tests/utils/gpu_tests/checkpoint/run_qwen_megatron_ckpt.sh deleted file mode 100644 index 17a69b710..000000000 --- a/tests/utils/gpu_tests/checkpoint/run_qwen_megatron_ckpt.sh +++ /dev/null @@ -1,91 +0,0 @@ -set -x - -# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml - -huggingface-cli download Qwen/Qwen2.5-0.5B - -export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='qwen2_5_0b5_megatron_saveload' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=1 $@ - - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ - actor_rollout_ref.actor.optim.lr=2e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=2e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.pipeline_model_parallel_size=2 \ - critic.megatron.virtual_pipeline_model_parallel_size=2 \ - critic.megatron.tensor_model_parallel_size=2 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=['console'] \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='qwen2_5_0b5_megatron_saveload' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.resume_mode=auto \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=1 $@ \ No newline at end of file diff --git a/tests/utils/gpu_tests/checkpoint/test_megatron_ckpt.py b/tests/utils/gpu_tests/checkpoint/test_megatron_ckpt.py deleted file mode 100644 index 54734f60d..000000000 --- a/tests/utils/gpu_tests/checkpoint/test_megatron_ckpt.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Using FSDPTrainer -""" -import os - -import hydra -import ray -from transformers import AutoTokenizer - -from verl.trainer.ppo.ray_trainer import RayPPOTrainer -from verl.utils.fs import copy_local_path_from_hdfs - -MODEL_PATH = 'Qwen/Qwen2.5-0.5B' -DATA_PATH = 'data/gsm8k/' -SAVE_PATH = '/tmp/checkpoint' - - -def make_reward_function(tokenizer, num_examine): - return None - - -additional_config = { - 'data': { - 'train_files': f'{DATA_PATH}/train.parquet', - 'val_files': f'{DATA_PATH}/test.parquet', - 'train_batch_size': 1024, - 'val_batch_size': 1312, - 'max_prompt_length': 512, - 'max_response_length': 512 - }, - 'actor_rollout_ref': { - 'model': { - 'path': MODEL_PATH - }, - 'actor': { - 'optim': { - 'lr': 2e-6 - }, - 'ppo_mini_batch_size': 32, - 'ppo_micro_batch_size_per_gpu': 1, - 'megatron': { - 'tensor_model_parallel_size': 2, - 'pipeline_model_parallel_size': 4, - } - }, - 'rollout': { - 'log_prob_micro_batch_size_per_gpu': 8, - 'tensor_model_parallel_size': 2, - 'name': 'vllm', - 'gpu_memory_utilization': 0.5 - }, - 'ref': { - 'log_prob_micro_batch_size_per_gpu': 16, - 'megatron': { - 'tensor_model_parallel_size': 2 - } - } - }, - 'critic': { - 'optim': { - 'lr': 2e-5 - }, - 'model': { - 'path': MODEL_PATH, - 'enable_gradient_checkpointing': False - }, - 'ppo_micro_batch_size_per_gpu': 4, - 'megatron': { - 'tensor_model_parallel_size': 2 - } - }, - 'algorithm': { - 'kl_ctrl': { - 'kl_coef': 0.001 - }, - 'adv_estimator': 'grpo', - }, - 'trainer': { - 'critic_warmup': 0, - 'logger': ['console'], - 'project_name': 'verl_megatron_gsm8k_examples', - 'experiment_name': 'qwen2_5_0b5_function_rm', - 'n_gpus_per_node': 8, - 'nnodes': 1, - 'save_freq': 1, - 'test_freq': 1, - 'total_epochs': 15, - 'total_training_steps': 3, - } -} - - -def check_result(origin_path, megatron_path, input_text): - from transformers import AutoModelForCausalLM - import torch - print("check result") - torch_dtype = torch.float16 - origin_model = AutoModelForCausalLM.from_pretrained( - origin_path, - torch_dtype=torch_dtype, - ).eval() - - origin_model = origin_model.to('cuda') - tokenizer = AutoTokenizer.from_pretrained(origin_path) - - inputs = tokenizer(input_text, return_tensors="pt").to('cuda') - origin_outputs = origin_model.generate(**inputs, max_new_tokens=8, do_sample=False) - origin_text = tokenizer.decode(origin_outputs[0], skip_special_tokens=True) - print(f"origin_text: {origin_text}") - - megatron_model = AutoModelForCausalLM.from_pretrained( - megatron_path, - torch_dtype=torch_dtype, - ).eval() - megatron_model = megatron_model.to('cuda') - megatron_outputs = megatron_model.generate(**inputs, max_new_tokens=8, do_sample=False) - megatron_text = tokenizer.decode(megatron_outputs[0], skip_special_tokens=True) - print(f"megatron_text: {megatron_text}") - - assert origin_text == megatron_text, "megatron ckpt is diff from origin ckpt" - - -@hydra.main(config_path='../../../verl/verl/trainer/config', config_name='ppo_megatron_trainer', version_base=None) -def main(config): - ray.init() - - from omegaconf import OmegaConf - from pprint import pprint - - additional_omegaconf = OmegaConf.create(additional_config) - config = OmegaConf.merge(config, additional_omegaconf) - - # print initial config - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - - # print the config - print('Config after normalizing batch_size') - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - - config.trainer.logger = ['console'] - # download the checkpoint from hdfs - local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) - local_path = os.path.expanduser(local_path) - # instantiate tokenizern - tokenizer = AutoTokenizer.from_pretrained(local_path) - print(f'Tokenizer vocab_size: {tokenizer.vocab_size}') - - # define worker classes - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.Critic: ray.remote(CriticWorker), - } - - global_pool_id = 'global_pool' - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - } - - reward_fn = make_reward_function(tokenizer=tokenizer, num_examine=1) - - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RayPPOTrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - reward_fn=reward_fn, - val_reward_fn=reward_fn) - trainer.init_workers() - print(f"actor model : {trainer.actor_rollout_wg}") - trainer.actor_rollout_wg.save_checkpoint(SAVE_PATH) - - -if __name__ == '__main__': - main() - check_result(MODEL_PATH, SAVE_PATH, "who are you?") diff --git a/tests/utils/gpu_tests/dataset/test_rm_dataset.py b/tests/utils/gpu_tests/dataset/test_rm_dataset.py deleted file mode 100644 index 72b4b3a23..000000000 --- a/tests/utils/gpu_tests/dataset/test_rm_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from verl.utils import hf_tokenizer -from verl.utils.dataset.rm_dataset import RMDataset - - -def get_rm_data(): - # prepare test dataset - local_folder = os.path.expanduser("~/verl-data/full_hh_rlhf/rm/") - local_path = os.path.join(local_folder, "test.parquet") - os.makedirs(local_folder, exist_ok=True) - return local_path - - -def test_rm_dataset(): - tokenizer = hf_tokenizer("facebook/opt-1.3b") - local_path = get_rm_data() - dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512) - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode(data) - assert len(output) > 1 - assert isinstance(output[0], str) diff --git a/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py b/tests/utils/megatron/test_pipeline_parallel.py similarity index 64% rename from tests/utils/gpu_tests/megatron/test_pipeline_parallel.py rename to tests/utils/megatron/test_pipeline_parallel.py index cf442a03b..24a416987 100644 --- a/tests/utils/gpu_tests/megatron/test_pipeline_parallel.py +++ b/tests/utils/megatron/test_pipeline_parallel.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards from verl.utils.megatron.pipeline_parallel import make_batch_generator @@ -45,3 +48,23 @@ def test_make_batch_generator_empty(): assert len(generators) == vpp_size for gen in generators: assert list(gen) == [] + + +@pytest.mark.parametrize( + "layer_num,pp_size,gt", + [ + (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]), + (61, 7, [8, 9, 9, 9, 9, 9, 8]), + (61, 1, [61]), + (61, 0, ValueError), + (10, 16, ValueError), + ], +) +def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt): + if isinstance(gt, list): + shards = get_dynamic_pipeline_shards(layer_num, pp_size) + assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}" + assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}" + elif issubclass(gt, Exception): + with pytest.raises(gt): + shards = get_dynamic_pipeline_shards(layer_num, pp_size) diff --git a/tests/reward_score/test_sandbox_fusion.py b/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py similarity index 81% rename from tests/reward_score/test_sandbox_fusion.py rename to tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py index 4c4a383fc..aaa427183 100644 --- a/tests/reward_score/test_sandbox_fusion.py +++ b/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py @@ -125,7 +125,8 @@ def test_integration_runtime_timeout(): @pytest.mark.skipif(skip_condition, reason=skip_reason) def test_integration_concurrency_high_load(): - """Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong answer, timeout)""" + """Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong + answer, timeout)""" concurrency_level = 100 # Indices for different expected outcomes wrong_answer_indices = {10, 25, 50} @@ -181,7 +182,10 @@ def test_integration_concurrency_high_load(): ) end_time = time.time() duration = end_time - start_time - print(f"\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, {len(timeout_indices)} timeouts) duration: {duration:.2f} seconds") + print( + f"\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, " + f"{len(timeout_indices)} timeouts) duration: {duration:.2f} seconds" + ) # Verify results against the expected map assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" @@ -202,7 +206,10 @@ def test_integration_concurrency_high_load(): else: unexpected_results.append((i, r, f"Expected {expected}")) - print(f"Correct results (True): {correct_count}/{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}") + print( + f"Correct results (True): {correct_count}/" + f"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}" + ) print(f"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}") print(f"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}") @@ -212,14 +219,18 @@ def test_integration_concurrency_high_load(): print(f" Index {idx}: Got {res}, {expected_str}. Metadata: {metadata_list[idx]}") raise AssertionError(f"Found {len(unexpected_results)} unexpected results.") - assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), "Incorrect number of successful results" + assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), ( + "Incorrect number of successful results" + ) assert wrong_count == len(wrong_answer_indices), "Incorrect number of identified wrong answers" assert timeout_count == len(timeout_indices), "Incorrect number of identified timeouts" # Verify metadata count and basic status of one of each type assert len(metadata_list) == concurrency_level # Find the first correct index - first_correct_index = next(i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices) + first_correct_index = next( + i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices + ) assert metadata_list[first_correct_index]["status"] == "success" assert metadata_list[first_correct_index]["stdout"] == f"output_{first_correct_index}\n" @@ -250,12 +261,21 @@ def test_unit_concurrency_order(mock_call_sandbox_api): def side_effect(*args, **kwargs): stdin = kwargs.get("stdin") if stdin == "input1": - return ({"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, None) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + None, + ) elif stdin == "input2": time.sleep(0.1) - return ({"status": "Success", "run_result": {"status": "Finished", "stdout": "output2", "return_code": 0}}, None) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output2", "return_code": 0}}, + None, + ) elif stdin == "input3": - return ({"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, None) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + None, + ) else: return (None, "Unknown input in mock") @@ -287,11 +307,17 @@ def test_unit_api_timeout_error_concurrent(mock_call_sandbox_api): def side_effect(*args, **kwargs): stdin = kwargs.get("stdin") if stdin == "input1": - return ({"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, None) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + None, + ) elif stdin == "input2_timeout": return (None, api_error_message) elif stdin == "input3": - return ({"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, None) + return ( + {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + None, + ) else: return (None, "Unknown input in mock") @@ -313,7 +339,8 @@ def side_effect(*args, **kwargs): MAX_GLOBAL_CONCURRENCY_LIMIT_TEST = 5 # Define the number of processes used in the test NUM_PROCESSES_TEST = 4 -# Define the number of tasks processed by check_correctness in each process (i.e., internal ThreadPoolExecutor's concurrency potential) +# Define the number of tasks processed by check_correctness in each process (i.e., internal +# ThreadPoolExecutor's concurrency potential) NUM_TASKS_PER_PROCESS_TEST = 3 # Simulate API call duration to ensure calls can overlap SIMULATED_API_CALL_DURATION_TEST = 0.2 # seconds @@ -331,6 +358,7 @@ def _mock_api_call_for_concurrency_tracking( stdin, compile_timeout, run_timeout, + memory_limit_mb, language, ): # entry_time = time.time() # For detailed logging @@ -339,7 +367,8 @@ def _mock_api_call_for_concurrency_tracking( if active_calls_counter.value > max_calls_tracker.value: max_calls_tracker.value = active_calls_counter.value # Optional debug log: - # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. Active: {active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}") + # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. Active: " + # f"{active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}") time.sleep(SIMULATED_API_CALL_DURATION_TEST) # Simulate actual work duration @@ -347,29 +376,69 @@ def _mock_api_call_for_concurrency_tracking( with call_lock: active_calls_counter.value -= 1 # Optional debug log: - # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: {active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s") + # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: " + # f"{active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s") # Return a simulated successful API response - return {"status": "Success", "run_result": {"status": "Finished", "stdout": f"mock_output_for_{stdin}", "return_code": 0}}, None + return { + "status": "Success", + "run_result": {"status": "Finished", "stdout": f"mock_output_for_{stdin}", "return_code": 0}, + }, None # --- Worker function for ProcessPoolExecutor --- # This function runs in each child process of ProcessPoolExecutor -def _process_pool_worker_for_concurrency_test(sandbox_url, in_outs, generation, language, timeout, mp_semaphore_for_check_correctness, active_calls_counter, max_calls_tracker, call_lock): +def _process_pool_worker_for_concurrency_test( + sandbox_url, + in_outs, + generation, + memory_limit_mb, + language, + timeout, + mp_semaphore_for_check_correctness, + active_calls_counter, + max_calls_tracker, + call_lock, +): # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage - curried_mock_api_call = lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, language: _mock_api_call_for_concurrency_tracking(active_calls_counter, max_calls_tracker, call_lock, sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, language) + curried_mock_api_call = ( + lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: ( + _mock_api_call_for_concurrency_tracking( + active_calls_counter, + max_calls_tracker, + call_lock, + sandbox_fusion_url, + code, + stdin, + compile_timeout, + run_timeout, + memory_limit_mb, + language, + ) + ) + ) # ---- START DEBUG PRINTS ---- import os import verl.utils.reward_score.sandbox_fusion.utils - print(f"[Worker PID:{os.getpid()}] Original call_sandbox_api: {verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", flush=True) + print( + f"[Worker PID:{os.getpid()}] Original call_sandbox_api: " + f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", + flush=True, + ) # ---- END DEBUG PRINTS ---- - with patch("verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", side_effect=curried_mock_api_call) as mock_obj: + with patch( + "verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", side_effect=curried_mock_api_call + ) as mock_obj: # ---- START DEBUG PRINTS ---- - print(f"[Worker PID:{os.getpid()}] Patched call_sandbox_api: {verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", flush=True) + print( + f"[Worker PID:{os.getpid()}] Patched call_sandbox_api: " + f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", + flush=True, + ) print(f"[Worker PID:{os.getpid()}] Mock object: {mock_obj}", flush=True) # ---- END DEBUG PRINTS ---- results, metadata_list = check_correctness( @@ -377,6 +446,7 @@ def _process_pool_worker_for_concurrency_test(sandbox_url, in_outs, generation, in_outs=in_outs, generation=generation, timeout=timeout, + memory_limit_mb=memory_limit_mb, language=language, concurrent_semaphore=mp_semaphore_for_check_correctness, # Pass multiprocessing.Semaphore ) @@ -403,12 +473,16 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): mock_sandbox_url = "mock_url_for_concurrency_test" mock_generation = "pass" # Specific code content is not important as API call is mocked + mock_memory_limit_mb = 1024 mock_language = "python" mock_timeout = 5 # Timeout setting, not critical for mock calls # Input/output data for each process # NUM_TASKS_PER_PROCESS_TEST tasks will be handled by check_correctness's internal ThreadPoolExecutor - process_in_outs = {"inputs": [f"task_input_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], "outputs": [f"task_output_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)]} + process_in_outs = { + "inputs": [f"task_input_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], + "outputs": [f"task_output_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], + } futures = [] total_tasks_expected_to_run = NUM_PROCESSES_TEST * NUM_TASKS_PER_PROCESS_TEST @@ -422,6 +496,7 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): mock_sandbox_url, process_in_outs, mock_generation, + mock_memory_limit_mb, mock_language, mock_timeout, global_mp_semaphore, # Global semaphore to test @@ -448,21 +523,32 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): # print(f"Tasks processed per worker: {num_tasks_processed_per_worker}") # Verify that all submitted tasks have been processed - assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, "Mismatch in the number of tasks processed." + assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, ( + "Mismatch in the number of tasks processed." + ) # Verify that the mock API was called at least once assert max_calls_tracker.value > 0, "The mocked API call_sandbox_api was not called." # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit - assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, f"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit ({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST})." + assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, ( + f"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit " + f"({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST})." + ) # Optional: Rough check on execution time to verify semaphore is working to limit concurrency # Theoretical minimum execution time = (Total tasks / Concurrency limit) * Single task duration # Actual time will be longer due to various overheads - min_expected_duration = (total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST + min_expected_duration = ( + total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST + ) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST # print(f"Minimum Expected Execution Time (approx): {min_expected_duration:.2f}s") # Allow some margin, e.g., 80% of theoretical minimum time - assert total_execution_time >= min_expected_duration * 0.8, f"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the semaphore might not be effectively limiting concurrency as expected (min expected: {min_expected_duration * 0.8:.2f}s)." + assert total_execution_time >= min_expected_duration * 0.8, ( + f"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the " + f"semaphore might not be effectively limiting concurrency as expected " + f"(min expected: {min_expected_duration * 0.8:.2f}s)." + ) # Ensure there is no more code after this point if these were the last functions. @@ -573,9 +659,34 @@ def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> L # Use a short timeout for fast tests results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5) # from verl.utils.reward_score.prime_code import apps_check_correctness - # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code, timeout=50000, debug=True) + # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code, + # timeout=50000, debug=True) assert results == [True, True] assert "error" not in metadata_list[0] assert metadata_list[0].get("status") != "compilation error" assert metadata_list[0].get("status") != "runtime error" + + +@pytest.mark.skipif(skip_condition, reason=skip_reason) +def test_none_and_empty_stdin_passed_correctly(): + """ + Tests that when stdin data is set to an empty string or None, it is still + is passed correctly to Sandbox Fusion as an empty string. + """ + echo_code = """ +import sys +print(f"You said '{sys.stdin.readline().strip()}'") +""" + in_outs = { + "inputs": [None, "", "hello"], + "outputs": ["You said ''", "You said ''", "You said 'hello'"], + } + + # Use a short timeout for fast tests + results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5) + + assert results == [True, True, True] + assert "error" not in metadata_list[0] + assert metadata_list[0].get("status") != "compilation error" + assert metadata_list[0].get("status") != "runtime error" diff --git a/tests/sandbox/test_sandbox.py b/tests/utils/reward_score/test_sandbox_on_cpu.py similarity index 93% rename from tests/sandbox/test_sandbox.py rename to tests/utils/reward_score/test_sandbox_on_cpu.py index e3e0b10db..ff4073232 100644 --- a/tests/sandbox/test_sandbox.py +++ b/tests/utils/reward_score/test_sandbox_on_cpu.py @@ -109,7 +109,9 @@ def test_parallelism(): ground_truth.extend(prime_math_gts) data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) - scores = asyncio.run(parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)) + scores = asyncio.run( + parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16) + ) print(scores) @@ -118,7 +120,7 @@ def test_prime_code(): Test PRIME code sandbox. """ data_source = "codecontests" - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): + for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): score = default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ @@ -134,8 +136,10 @@ def test_prime_code_sandbox_fusion(): sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL") # Removed the previous 'if not sandbox_url' check block - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): - score = default_compute_score(data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}) # <-- Use the URL obtained from the environment variable + for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): + score = default_compute_score( + data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url} + ) # <-- Use the URL obtained from the environment variable assert float(score) == score_ @@ -150,7 +154,9 @@ def test_continuous_score_consistency(): expected_continuous_score = 0.9 # 1. Calculate score using prime_code (default) with continuous=True - prime_score, _ = sandbox_fusion.compute_score(os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True) + prime_score, _ = sandbox_fusion.compute_score( + os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True + ) # 2. Calculate score using sandbox_fusion with continuous=True # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score @@ -174,6 +180,6 @@ def test_check_correctness(): def test_prime_math(): data_source = "numina_aops_forum" - for completion, ground_truth in zip(prime_math_answers, prime_math_gts): + for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True): score = default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/tests/utils/gpu_tests/test_activation_offload.py b/tests/utils/test_activation_offload.py similarity index 86% rename from tests/utils/gpu_tests/test_activation_offload.py rename to tests/utils/test_activation_offload.py index c46690630..2393d7962 100644 --- a/tests/utils/gpu_tests/test_activation_offload.py +++ b/tests/utils/test_activation_offload.py @@ -43,16 +43,28 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy config = Qwen2Config(num_hidden_layers=4) with torch.device("cuda"): - model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + model = AutoModelForCausalLM.from_config( + config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) model = model.to(device="cuda") # Wrap model with FSDP mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) if strategy == "fsdp": - model = FSDP(model, use_orig_params=False, device_id=torch.cuda.current_device(), sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision, device_mesh=device_mesh, auto_wrap_policy=get_fsdp_wrap_policy(module=model)) + model = FSDP( + model, + use_orig_params=False, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=device_mesh, + auto_wrap_policy=get_fsdp_wrap_policy(module=model), + ) else: - mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + ) fsdp_kwargs = { "mesh": device_mesh, "mp_policy": mp_policy, @@ -64,7 +76,9 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy # Create checkpoint manager tokenizer = AutoTokenizer.from_pretrained(model_name) - checkpoint_manager = FSDPCheckpointManager(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer) + checkpoint_manager = FSDPCheckpointManager( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer + ) # Generate sample input batch_size = 2 diff --git a/tests/utils/test_config_on_cpu.py b/tests/utils/test_config_on_cpu.py new file mode 100644 index 000000000..42dc8e1f2 --- /dev/null +++ b/tests/utils/test_config_on_cpu.py @@ -0,0 +1,95 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import dataclass + +from omegaconf import OmegaConf + +from verl.utils import omega_conf_to_dataclass + + +@dataclass +class TestDataclass: + hidden_size: int + activation: str + + +@dataclass +class TestTrainConfig: + batch_size: int + model: TestDataclass + + +_cfg_str = """train_config: + batch_size: 32 + model: + hidden_size: 768 + activation: relu""" + + +class TestConfigOnCPU(unittest.TestCase): + """Test cases for configuration utilities on CPU. + + Test Plan: + 1. Test basic OmegaConf to dataclass conversion for simple nested structures + 2. Test nested OmegaConf to dataclass conversion for complex hierarchical configurations + 3. Verify all configuration values are correctly converted and accessible + """ + + def setUp(self): + self.config = OmegaConf.create(_cfg_str) + + def test_omega_conf_to_dataclass(self): + sub_cfg = self.config.train_config.model + cfg = omega_conf_to_dataclass(sub_cfg, TestDataclass) + self.assertEqual(cfg.hidden_size, 768) + self.assertEqual(cfg.activation, "relu") + assert isinstance(cfg, TestDataclass) + + def test_nested_omega_conf_to_dataclass(self): + cfg = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig) + self.assertEqual(cfg.batch_size, 32) + self.assertEqual(cfg.model.hidden_size, 768) + self.assertEqual(cfg.model.activation, "relu") + assert isinstance(cfg, TestTrainConfig) + assert isinstance(cfg.model, TestDataclass) + + +class TestPrintCfgCommand(unittest.TestCase): + """Test suite for the print_cfg.py command-line tool.""" + + def test_command_with_override(self): + """Test that the command runs without error when overriding config values.""" + import subprocess + + # Run the command + result = subprocess.run( + ["python3", "scripts/print_cfg.py", "critic.profiler.discrete=True", "+critic.profiler.extra.any_key=val"], + capture_output=True, + text=True, + ) + + # Verify the command exited successfully + self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}") + + # Verify the output contains expected config information + self.assertIn("critic", result.stdout) + self.assertIn("profiler", result.stdout) + self.assertIn("discrete=True", result.stdout) + self.assertIn("extra={'any_key': 'val'}", result.stdout) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/gpu_tests/test_flops_counter.py b/tests/utils/test_flops_counter.py similarity index 77% rename from tests/utils/gpu_tests/test_flops_counter.py rename to tests/utils/test_flops_counter.py index fa05c93c1..0b8889b3a 100644 --- a/tests/utils/gpu_tests/test_flops_counter.py +++ b/tests/utils/test_flops_counter.py @@ -39,9 +39,12 @@ def __init__(self, config_dict): "num_key_value_heads": 32, }, "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim - # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*32*4096 - # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*32*4096 + # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + + # 12*sum(seqlen^2)*layer*head*head_dim + # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) + + # 12*(512*512+1024*1024+2048*2048)*32*4096 + # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) + + # 12*(4096*4096+4096*4096+4096*4096)*32*4096 "expected_flops_tuple": (153555818250240 / 1e12, 575955114393600 / 1e12), }, "qwen2": { @@ -55,9 +58,12 @@ def __init__(self, config_dict): "num_key_value_heads": 4, }, "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim - # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*28*3584 - # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*28*3584 + # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + + # 12*sum(seqlen^2)*layer*head*head_dim + # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) + + # 12*(512*512+1024*1024+2048*2048)*28*3584 + # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) + + # 12*(4096*4096+4096*4096+4096*4096)*28*3584 "expected_flops_tuple": (170388331954176 / 1e12, 622070178250752 / 1e12), }, "qwen3": { @@ -72,9 +78,12 @@ def __init__(self, config_dict): "head_dim": 128, }, "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim - # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*36*128*32 - # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*36*128*32 + # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + + # 12*sum(seqlen^2)*layer*head*head_dim + # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) + + # 12*(512*512+1024*1024+2048*2048)*36*128*32 + # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) + + # 12*(4096*4096+4096*4096+4096*4096)*36*128*32 "expected_flops_tuple": (185867930959872 / 1e12, 692924253732864 / 1e12), }, "qwen3_moe": { @@ -91,9 +100,12 @@ def __init__(self, config_dict): "num_experts": 128, }, "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 + hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim - # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*48*128*32 - # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*48*128*32 + # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 + + # hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim + # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) + + # 12*(512*512+1024*1024+2048*2048)*48*128*32 + # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) + + # 12*(4096*4096+4096*4096+4096*4096)*48*128*32 "expected_flops_tuple": (85087060230144 / 1e12, 365944098521088 / 1e12), }, "deepseek_v3": { @@ -117,8 +129,10 @@ def __init__(self, config_dict): }, "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280 - # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) + 12*(512*512+1024*1024+2048*2048)*61*192*128 - # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) + 12*(4096*4096+4096*4096+4096*4096)*61*192*128 + # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) + + # 12*(512*512+1024*1024+2048*2048)*61*192*128 + # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) + + # 12*(4096*4096+4096*4096+4096*4096)*61*192*128 "expected_flops_tuple": (906535995703296 / 1e12, 3674028304760832 / 1e12), }, } @@ -132,8 +146,12 @@ def test_flops_counter(config_type: str): test_config = CONFIG[config_type] config = Config(test_config["config"]) flops_counter = FlopsCounter(config) - for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]): + for batch_seqlens, expected_flops in zip( + test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"], strict=True + ): # set delta time to 1 to get the flops counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1) print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}") - assert math.isclose(counted_flops, expected_flops), f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" + assert math.isclose(counted_flops, expected_flops), ( + f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" + ) diff --git a/tests/utils/cpu_tests/test_fs.py b/tests/utils/test_fs_on_cpu.py similarity index 100% rename from tests/utils/cpu_tests/test_fs.py rename to tests/utils/test_fs_on_cpu.py diff --git a/tests/utils/cpu_tests/test_import_utils.py b/tests/utils/test_import_utils_on_cpu.py similarity index 100% rename from tests/utils/cpu_tests/test_import_utils.py rename to tests/utils/test_import_utils_on_cpu.py diff --git a/tests/kernels/test_linear_cross_entropy.py b/tests/utils/test_linear_cross_entropy.py similarity index 62% rename from tests/kernels/test_linear_cross_entropy.py rename to tests/utils/test_linear_cross_entropy.py index f0fd0e1a6..0512d1376 100644 --- a/tests/kernels/test_linear_cross_entropy.py +++ b/tests/utils/test_linear_cross_entropy.py @@ -29,23 +29,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing +import os import torch import verl.utils.torch_functional as verl_F from verl.utils.experimental.torch_functional import FusedLinearForPPO +from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy from verl.utils.torch_functional import logprobs_from_logits compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) fused_linear_for_ppo = FusedLinearForPPO() fused_linear_for_ppo.compile(dynamic=True) +MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) -def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction="none") -> typing.List[torch.Tensor]: + +def run_torch_entropy( + hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" +) -> list[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] @@ -55,10 +61,16 @@ def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch. return logprobs, entropy -def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]: +def run_verl_original_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +) -> list[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] + logits /= temperature # compute entropy entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) @@ -67,23 +79,27 @@ def run_verl_original_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels # To be tested -def run_verl_torch_fused_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor): +def run_verl_torch_fused_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, +): hidden = hidden.to(torch.float32) weight = weight.to(torch.float32) logprobs, entropy = fused_linear_for_ppo( hidden, weight, labels, + temperature=temperature, ) return logprobs.squeeze(0), entropy.squeeze(0) -MAX_TEST_CASES = 5 - - class TestLinearCrossEntropy: - def __init__(self, test_case_idx: int) -> None: + def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None: self.test_case_idx = test_case_idx + self.temperature = temperature def cleanup(self): torch.cuda.empty_cache() @@ -94,6 +110,8 @@ def cleanup(self): torch.cuda.synchronize() def generate_hyper(self): + global MAX_TEST_CASES + self.dtype = torch.bfloat16 if self.test_case_idx == 0: self.batch_size = 1 @@ -121,11 +139,20 @@ def generate_hyper(self): self.hidden_size = 4096 self.vocab_size = 102400 else: - raise ValueError(f"Invalid test case index: {test_case_idx}") + raise ValueError(f"Invalid test case index: {self.test_case_idx}") + assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." def generate_forward_inputs(self): - hidden = torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() - weight = torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5).requires_grad_() + hidden = ( + torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_() + ) + weight = ( + torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_() + ) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") return hidden, weight, labels @@ -144,6 +171,8 @@ def verify_correctness(self, iterations=5): verl_backward_latency = list() verl_fused_forward_latency = list() verl_fused_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -153,57 +182,97 @@ def verify_correctness(self, iterations=5): hidden, weight, labels = self.generate_forward_inputs() start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels) + (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature) end_event.record() torch.cuda.synchronize() verl_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(hidden, weight, labels) + (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy( + hidden, weight, labels, self.temperature + ) end_event.record() torch.cuda.synchronize() verl_fused_forward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) + torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + # backward g_entropy, g_logprobs = self.generate_backward_inputs() start_event.record() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + (d_torch_hidden, d_torch_weight) = torch.autograd.grad( + (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) end_event.record() torch.cuda.synchronize() torch_backward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (d_verl_hidden, d_verl_weight) = torch.autograd.grad((verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + (d_verl_hidden, d_verl_weight) = torch.autograd.grad( + (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) end_event.record() torch.cuda.synchronize() verl_backward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad((verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad( + (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) end_event.record() torch.cuda.synchronize() verl_fused_backward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( + (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) # remove first latency torch_forward_latency = torch_forward_latency[1:] @@ -212,15 +281,43 @@ def verify_correctness(self, iterations=5): verl_backward_latency = verl_backward_latency[1:] verl_fused_forward_latency = verl_fused_forward_latency[1:] verl_fused_backward_latency = verl_fused_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] print("\n[INFO]: Verified forward & backward correctness.") - print(f"[INFO]: Forward pass: Torch implementation average time: {sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") - print(f"[INFO]: Backward pass: torch implementation average time: {sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") - print(f"[INFO]: Forward pass: VeRL implementation average time: {sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms") - print(f"[INFO]: Backward pass: VeRL implementation average time: {sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms") - print(f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms") - print(f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: {sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms") + print( + f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms" + ) + print( + f"[INFO]: Forward pass: VeRL implementation average time: " + f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: VeRL implementation average time: " + f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms" + ) + print( + f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: " + f"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: " + f"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms" + ) + print( + f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: kernel implementation average time: " + f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms" + ) def check_storage(self, method_name, run_forward): self.cleanup() @@ -229,7 +326,7 @@ def check_storage(self, method_name, run_forward): hidden, weight, labels = self.generate_forward_inputs() torch.cuda.reset_peak_memory_stats() - (logprobs, entropy) = run_forward(hidden, weight, labels) + (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature) torch.cuda.synchronize() torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") @@ -237,7 +334,9 @@ def check_storage(self, method_name, run_forward): g_entropy, g_logprobs = self.generate_backward_inputs() torch.cuda.reset_peak_memory_stats() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) + (d_torch_hidden, d_torch_weight) = torch.autograd.grad( + (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) torch.cuda.synchronize() torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") @@ -246,6 +345,7 @@ def check_storage_all(self): self.check_storage("Torch", run_torch_entropy) self.check_storage("VeRL", run_verl_original_entropy) self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy) + self.check_storage("Kernel", linear_cross_entropy) if __name__ == "__main__": diff --git a/tests/utils/test_linear_cross_entropy_tp.py b/tests/utils/test_linear_cross_entropy_tp.py new file mode 100644 index 000000000..9c1f868a9 --- /dev/null +++ b/tests/utils/test_linear_cross_entropy_tp.py @@ -0,0 +1,514 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torch.distributed as dist + +try: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy +except ImportError: + # FIXME: remove these manually included paths + import sys + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) +finally: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + +import verl.utils.torch_functional as verl_F + +compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + +MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) +VERIFY_TORCH_SELF = os.environ.get("VERIFY_TORCH_SELF", False) +LOW_MEMORY = os.environ.get("LOW_MEMORY", False) +LOW_MEMORY_DIV_FACTOR = os.environ.get("LOW_MEMORY_DIV_FACTOR", 16) + + +def run_torch_entropy( + hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" +) -> list[torch.Tensor]: + # [num_tokens, vocab_size] + if len(hidden.shape) > 2: + hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] + if len(labels.shape) > 1: + labels = labels.view(-1) + logits = torch.matmul( + hidden.to(torch.float32), + weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32), + ) + logits /= temperature + pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] + entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens] + logprobs = torch.neg(logprobs) + return logprobs, entropy + + +class TorchEntropyTP(torch.autograd.Function): + """ + it is used for testing the correctness of the kernel + it is not efficient and is not recommended to use in practice + """ + + @staticmethod + def forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, + dist_process_group: torch.distributed.ProcessGroup, + ): + # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size] + ctx.original_hidden_shape = hidden.shape + if len(hidden.shape) > 2: + hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] + if len(labels.shape) > 1: + labels = labels.view(-1) + + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size] + logits /= temperature + whole_logits = torch.empty( + (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), + dtype=logits.dtype, + device=logits.device, + ) + whole_logits_ref = [ + whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]] + for i in range(dist.get_world_size(dist_process_group)) + ] + dist.all_gather(whole_logits_ref, logits, group=dist_process_group) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + + logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.neg(logprobs) + + ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) + ctx.dist_process_group = dist_process_group + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): + hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors + dist_process_group = ctx.dist_process_group + temperature = ctx.temperature + batch_size, hidden_size = hidden.shape + vocab_size, hidden_size = weight.shape + rank = dist.get_rank(dist_process_group) + + # Compute softmax probabilities + maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) + exp_logits = torch.exp(whole_logits - maximum) + accumulate = exp_logits.sum(dim=-1, keepdim=True) + pd = exp_logits / accumulate + + # Gradient for entropy + # entropy = entropy_a - entropy_b + # entropy_a = log(sum(exp(logits))) + # entropy_b = sum(pd * logits) + # d_entropy_a/d_logits = pd + # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = d_entropy_a - d_entropy_b + # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) + d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + + # Gradient for logprobs + # logprobs = -cross_entropy = -log(pd[labels]) + # d_logprobs/d_logits = (pd - one_hot(labels)) + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + g_logprobs = torch.neg(g_logprobs) + d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) + # NOTE: This will lead to wrong result + # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot + + # Combine gradients + d_logits = d_logits_entropy + d_logits_logprobs + d_logits /= temperature + + # Get local slice of gradients + local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size] + + # Compute gradients for hidden and weight + d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32)) + d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32)) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return d_hidden, d_weight, None, None, None + + +run_torch_entropy_tp = TorchEntropyTP.apply + + +class TestLinearCrossEntropy_TensorParallel: + def __init__(self): + dist.init_process_group(backend="nccl") + self.group = dist.group.WORLD + + self.local_rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") + + def initialize(self, test_case_idx: int, temperature: float = 1.5): + self.test_case_idx = test_case_idx + self.temperature = temperature + + def shutdown(self): + dist.destroy_process_group() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES + + self.dtype = torch.bfloat16 + if self.test_case_idx == 0: + self.batch_size = 1 + self.num_tokens = 1937 + self.hidden_size = 3584 + self.vocab_size = 152064 + elif self.test_case_idx == 1: + self.batch_size = 1 + self.num_tokens = 2169 + self.hidden_size = 896 + self.vocab_size = 151936 + elif self.test_case_idx == 2: + self.batch_size = 1 + self.num_tokens = 1530 + self.hidden_size = 2048 + self.vocab_size = 32256 + elif self.test_case_idx == 3: + self.batch_size = 1 + self.num_tokens = 1388 + self.hidden_size = 4096 + self.vocab_size = 102400 + elif self.test_case_idx == 4: + self.batch_size = 1 + self.num_tokens = 8192 + self.hidden_size = 4096 + self.vocab_size = 102400 + else: + raise ValueError(f"Invalid test case index: {self.test_case_idx}") + if LOW_MEMORY: + self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR) + assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." + + def generate_forward_inputs(self): + hidden = ( + torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_() + ) + weight = ( + torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_() + ) + labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) + g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) + return g_entropy, g_logprobs + + def verify_torch_itself(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + # forward pass + # Create a tensor to hold the gathered weights from all ranks + # weight has shape [vocab_size, hidden_size] + # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size] + + # Create a single contiguous tensor to hold all gathered weights + whole_weight = torch.empty( + (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device + ) + + # Create views into the tensor for each rank's portion + whole_weight_views = [ + whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size) + ] + + # Perform all_gather operation using the views + dist.all_gather(whole_weight_views, weight, group=self.group) + + # Set requires_grad for autograd + whole_weight.requires_grad_() + + (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature) + + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + + torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + (single_d_hidden, single_d_weight) = torch.autograd.grad( + (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False + ) + + (tp_d_hidden, tp_d_weight) = torch.autograd.grad( + (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + # Extract the corresponding slice from single_d_weight for comparison + # tp_d_weight has shape [vocab_size, hidden_size] + # single_d_weight has shape [vocab_size * world_size, hidden_size] + torch.testing.assert_close( + tp_d_weight, + single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size], + atol=1e-2, + rtol=1e-4, + ) + + # atol=1e-3, rtol=1e-4) + if self.local_rank == 0: + print("[PASS] torch TP correctness is verified") + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + torch.cuda.synchronize() + forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_tp_hidden, d_tp_weight) = torch.autograd.grad( + (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + torch.cuda.synchronize() + backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") + print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + + def verify_kernel_correctness(self, iterations: int = 5): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy( + hidden, weight, labels, self.temperature, "none", self.group + ) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + start_event.record() + (torch_d_hidden, torch_d_weight) = torch.autograd.grad( + (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + start_event.record() + (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad( + (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2) + torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + if self.local_rank == 0: + print("\n[PASS]: Verified kernel forward & backward correctness.") + + print( + f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms" + ) + print( + f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms" + ) + print( + f"[INFO]: Backward pass: kernel implementation average time: " + f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms" + ) + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy( + hidden, weight, labels, self.temperature, "none", self.group + ) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( + (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + ) + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + +if __name__ == "__main__": + # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py + + # Check if running with torchrun (distributed mode) + assert int(os.environ["WORLD_SIZE"]) > 1, ( + "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to " + "execute this script." + ) + torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) + + # set_backward_method(BackwardEnum._Total_Fuse_MN) + # set_backward_method(BackwardEnum._Split_Dlogits_N) + + test = TestLinearCrossEntropy_TensorParallel() + for test_case_idx in range(MAX_TEST_CASES): + print(f"[INFO] Running test case {test_case_idx}") + test.initialize(test_case_idx) + if VERIFY_TORCH_SELF: + test.verify_torch_itself() + test.check_torch_storage() + test.verify_kernel_correctness() + test.check_kernel_storage() + + test.shutdown() diff --git a/tests/utils/cpu_tests/test_model.py b/tests/utils/test_model_on_cpu.py similarity index 93% rename from tests/utils/cpu_tests/test_model.py rename to tests/utils/test_model_on_cpu.py index 94d475562..8b1416c8a 100644 --- a/tests/utils/cpu_tests/test_model.py +++ b/tests/utils/test_model_on_cpu.py @@ -33,7 +33,9 @@ def test_update_model_config(override_kwargs): handling both plain and nested overrides via parametrization. """ # Create a fresh mock config object for each test case - mock_config = SimpleNamespace(param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me") + mock_config = SimpleNamespace( + param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me" + ) # Apply the updates using the parametrized override_kwargs update_model_config(mock_config, override_kwargs) diff --git a/tests/utils/test_nvtx_profile.py b/tests/utils/test_nvtx_profile.py new file mode 100644 index 000000000..817d03000 --- /dev/null +++ b/tests/utils/test_nvtx_profile.py @@ -0,0 +1,179 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +from verl.utils import omega_conf_to_dataclass +from verl.utils.profiler import ProfilerConfig +from verl.utils.profiler.nvtx_profile import NsightSystemsProfiler + + +class TestProfilerConfig(unittest.TestCase): + def test_config_init(self): + import os + + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + cfg = compose(config_name="ppo_trainer") + arr = cfg.actor_rollout_ref + for config in [ + cfg.critic.profiler, + arr.profiler, + cfg.reward_model.profiler, + ]: + profiler_config = omega_conf_to_dataclass(config) + self.assertEqual(profiler_config.discrete, config.discrete) + self.assertEqual(profiler_config.all_ranks, config.all_ranks) + self.assertEqual(profiler_config.ranks, config.ranks) + assert isinstance(profiler_config, ProfilerConfig) + with self.assertRaises(AttributeError): + _ = profiler_config.non_existing_key + assert config.get("non_existing_key") == profiler_config.get("non_existing_key") + assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1) + assert config["discrete"] == profiler_config["discrete"] + from dataclasses import FrozenInstanceError + + with self.assertRaises(FrozenInstanceError): + profiler_config.discrete = False + + def test_frozen_config(self): + """Test that modifying frozen keys in ProfilerConfig raises exceptions.""" + from dataclasses import FrozenInstanceError + + from verl.utils.profiler.config import ProfilerConfig + + # Create a new ProfilerConfig instance + config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0]) + + # Test direct attribute assignment + with self.assertRaises(FrozenInstanceError): + config.discrete = False + + with self.assertRaises(FrozenInstanceError): + config.all_ranks = True + + with self.assertRaises(FrozenInstanceError): + config.ranks = [1, 2, 3] + + # Test dictionary-style assignment + with self.assertRaises(TypeError): + config["discrete"] = False + + with self.assertRaises(TypeError): + config["all_ranks"] = True + + with self.assertRaises(TypeError): + config["ranks"] = [1, 2, 3] + + config["extra"]["key"] = "value" + + +class TestNsightSystemsProfiler(unittest.TestCase): + """Test suite for NsightSystemsProfiler functionality. + + Test Plan: + 1. Initialization: Verify profiler state after creation + 2. Basic Profiling: Test start/stop functionality + 3. Discrete Mode: Test discrete profiling behavior + 4. Annotation: Test the annotate decorator in both normal and discrete modes + 5. Config Validation: Verify proper config initialization from OmegaConf + """ + + def setUp(self): + self.config = ProfilerConfig(all_ranks=True) + self.rank = 0 + self.profiler = NsightSystemsProfiler(self.rank, self.config) + + def test_initialization(self): + self.assertEqual(self.profiler.this_rank, True) + self.assertEqual(self.profiler.this_step, False) + self.assertEqual(self.profiler.discrete, False) + + def test_start_stop_profiling(self): + with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: + # Test start + self.profiler.start() + self.assertTrue(self.profiler.this_step) + mock_start.assert_called_once() + + # Test stop + self.profiler.stop() + self.assertFalse(self.profiler.this_step) + mock_stop.assert_called_once() + + def test_discrete_profiling(self): + discrete_config = ProfilerConfig(discrete=True, all_ranks=True) + profiler = NsightSystemsProfiler(self.rank, discrete_config) + + with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: + profiler.start() + self.assertTrue(profiler.this_step) + mock_start.assert_not_called() # Shouldn't start immediately in discrete mode + + profiler.stop() + self.assertFalse(profiler.this_step) + mock_stop.assert_not_called() # Shouldn't stop immediately in discrete mode + + def test_annotate_decorator(self): + mock_self = MagicMock() + mock_self.profiler = self.profiler + mock_self.profiler.this_step = True + + @NsightSystemsProfiler.annotate(message="test") + def test_func(self, *args, **kwargs): + return "result" + + with ( + patch("torch.cuda.profiler.start") as mock_start, + patch("torch.cuda.profiler.stop") as mock_stop, + patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, + patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, + ): + result = test_func(mock_self) + self.assertEqual(result, "result") + mock_start_range.assert_called_once() + mock_end_range.assert_called_once() + mock_start.assert_not_called() # Not discrete mode + mock_stop.assert_not_called() # Not discrete mode + + def test_annotate_discrete_mode(self): + discrete_config = ProfilerConfig(discrete=True, all_ranks=True) + profiler = NsightSystemsProfiler(self.rank, discrete_config) + mock_self = MagicMock() + mock_self.profiler = profiler + mock_self.profiler.this_step = True + + @NsightSystemsProfiler.annotate(message="test") + def test_func(self, *args, **kwargs): + return "result" + + with ( + patch("torch.cuda.profiler.start") as mock_start, + patch("torch.cuda.profiler.stop") as mock_stop, + patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, + patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, + ): + result = test_func(mock_self) + self.assertEqual(result, "result") + mock_start_range.assert_called_once() + mock_end_range.assert_called_once() + mock_start.assert_called_once() # Should start in discrete mode + mock_stop.assert_called_once() # Should stop in discrete mode + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_rollout_trace_on_cpu.py b/tests/utils/test_rollout_trace_on_cpu.py new file mode 100644 index 000000000..04dfbeef8 --- /dev/null +++ b/tests/utils/test_rollout_trace_on_cpu.py @@ -0,0 +1,170 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op + + +@pytest.fixture(autouse=True) +def reset_rollout_trace_config_singleton(): + """Fixture to reset the RolloutTraceConfig singleton before each test.""" + RolloutTraceConfig.reset() + + +@pytest.fixture +def mock_weave_client(): + """Mocks the weave module and its client, yielding the mock client.""" + mock_weave = MagicMock() + mock_client = MagicMock() + mock_call = MagicMock() + mock_client.create_call.return_value = mock_call + mock_weave.init.return_value = mock_client + + # Also mock the call_context if it's used internally by the decorator + mock_weave.trace.context.call_context.return_value = MagicMock() + + with patch.dict(sys.modules, {"weave": mock_weave, "weave.trace.context": mock_weave.trace.context}): + yield mock_client + + +class TracedClass: + @rollout_trace_op + # @weave.op + # @mlflow.trace + async def my_method(self, a, b="default"): + return f"result: {a}, {b}" + + @rollout_trace_op + # @weave.op + # @mlflow.trace + async def middle_method(self, a, b="default"): + await self.my_method("test_a1", b="test_b1") + return f"result: {a}, {b}" + + @rollout_trace_op + # @mlflow.trace + async def my_method_with_exception(self): + raise ValueError("Test Exception") + + async def upper_method(self): + await self.my_method("test_a0", b="test_b0") + await self.middle_method("test_a2", b="test_b2") + return True + + +class UntracedClass: + @rollout_trace_op + async def my_method(self, x): + return x * 2 + + +async def test_rollout_trace_on_untraced_class(): + """Tests that the decorator works correctly when no backend is configured.""" + instance = UntracedClass() + assert await instance.my_method(10) == 20 + + +async def test_rollout_trace_with_tracer(mock_weave_client): + """Tests that the decorator calls the tracer's methods correctly.""" + RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + instance = TracedClass() + assert RolloutTraceConfig.get_client() is mock_weave_client + + result = await instance.my_method("test_a", b="test_b") + + assert result == "result: test_a, test_b" + mock_weave_client.create_call.assert_called_once() + call_kwargs = mock_weave_client.create_call.call_args.kwargs + assert call_kwargs["op"] == "TracedClass.my_method" + expected_inputs = {"a": "test_a", "b": "test_b"} + assert call_kwargs["inputs"] == expected_inputs + + mock_call = mock_weave_client.create_call.return_value + mock_weave_client.finish_call.assert_called_once_with(mock_call, output=result) + + +async def test_rollout_trace_with_exception(mock_weave_client): + """Tests that `finish` is called with the exception when one is raised.""" + RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + instance = TracedClass() + + with pytest.raises(ValueError, match="Test Exception"): + await instance.my_method_with_exception() + + mock_weave_client.create_call.assert_called_once() + mock_call = mock_weave_client.create_call.return_value + mock_weave_client.finish_call.assert_called_once() + + # Check that finish_call was called with the exception + args, kwargs = mock_weave_client.finish_call.call_args + assert args[0] == mock_call + assert "exception" in kwargs + assert isinstance(kwargs["exception"], ValueError) + + +async def test_rollout_trace_with_dummy_backend(mock_weave_client): + """Tests that the tracer is not called when the backend is 'dummy'.""" + RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy") + instance = TracedClass() + + await instance.my_method("test_a") + + mock_weave_client.create_call.assert_not_called() + + +@pytest.mark.skipif( + os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true", + reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.", +) +async def test_rollout_trace_with_real_weave_backend(): + """Integration test with a real weave backend.""" + + # This assumes that the weave environment (e.g., project) is configured + RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + + instance = TracedClass() + + with rollout_trace_attr(step=1, sample_index=2, rollout_n=3): + await instance.upper_method() + + with pytest.raises(ValueError, match="Test Exception"): + await instance.my_method_with_exception() + + print("\nWeave integration test ran successfully. Check your weave project for the trace.") + + +@pytest.mark.skipif( + os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true", + reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.", +) +async def test_rollout_trace_with_real_mlflow_backend(): + """Integration test with a real mlflow backend.""" + + # This assumes that the mlflow environment (e.g., project) is configured + RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow") + + instance = TracedClass() + + with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"): + assert await instance.upper_method() + + # with pytest.raises(ValueError, match="Test Exception"): + # await instance.my_method_with_exception() + + print("\nWeave integration test ran successfully. Check your weave project for the trace.") diff --git a/tests/utils/gpu_tests/test_seqlen_balancing.py b/tests/utils/test_seqlen_balancing.py similarity index 58% rename from tests/utils/gpu_tests/test_seqlen_balancing.py rename to tests/utils/test_seqlen_balancing.py index 154f0deb9..9de777f1c 100644 --- a/tests/utils/gpu_tests/test_seqlen_balancing.py +++ b/tests/utils/test_seqlen_balancing.py @@ -18,13 +18,21 @@ from verl import DataProto from verl.utils.model import create_random_mask -from verl.utils.seqlen_balancing import ceildiv, get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import ( + ceildiv, + get_reverse_idx, + prepare_dynamic_batch, + rearrange_micro_batches, + restore_dynamic_batch, +) def test_seqlen_balancing(): input_ids = torch.randint(low=0, high=10, size=(20, 100)) - attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5) + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 + ) data = {"input_ids": input_ids, "attention_mask": attention_mask} dataproto = DataProto.from_single_dict(data) micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) @@ -38,6 +46,20 @@ def test_seqlen_balancing(): torch.testing.assert_close(new_batch, dataproto.batch) +def test_dynamic_batch(): + input_ids = torch.randint(low=0, high=10, size=(20, 100)) + + attention_mask = create_random_mask( + input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 + ) + data = {"input_ids": input_ids, "attention_mask": attention_mask} + dataproto = DataProto.from_single_dict(data) + micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300) + input_ids = torch.cat([micro_batch.batch["input_ids"] for micro_batch in micro_batches], dim=0) + input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst) + torch.testing.assert_close(input_ids, dataproto.batch["input_ids"]) + + def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb): # 1) init process group & CUDA torch.cuda.set_device(rank) @@ -102,6 +124,60 @@ def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb): dist.destroy_process_group() +def test_dataproto_split_uneven(): + """Test DataProto.split with uneven splits""" + # Create test data with 10 items + input_ids = torch.randint(low=0, high=10, size=(10, 5)) + attention_mask = torch.ones(10, 5) + data = {"input_ids": input_ids, "attention_mask": attention_mask} + dataproto = DataProto.from_single_dict(data) + + # Test split with size 3 (should create chunks of [3, 3, 3, 1]) + splits = dataproto.split(3) + assert len(splits) == 4 + assert len(splits[0]) == 3 + assert len(splits[1]) == 3 + assert len(splits[2]) == 3 + assert len(splits[3]) == 1 + + reconstructed = DataProto.concat(splits) + torch.testing.assert_close(reconstructed.batch["input_ids"], dataproto.batch["input_ids"]) + torch.testing.assert_close(reconstructed.batch["attention_mask"], dataproto.batch["attention_mask"]) + + # Test split with size equal to length (should create one chunk) + splits = dataproto.split(10) + assert len(splits) == 1 + assert len(splits[0]) == 10 + + # Test split with size larger than length (should create one chunk with all data) + splits = dataproto.split(15) + assert len(splits) == 1 + assert len(splits[0]) == 10 + + # Test with non-tensor batch data + import numpy as np + + data_with_non_tensor = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": np.array([f"label_{i}" for i in range(10)], dtype=object), + } + dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor) + + splits = dataproto_with_non_tensor.split(3) + assert len(splits) == 4 + assert len(splits[0]) == 3 + assert len(splits[1]) == 3 + assert len(splits[2]) == 3 + assert len(splits[3]) == 1 + + # Verify non-tensor data integrity + reconstructed = DataProto.concat(splits) + np.testing.assert_array_equal( + reconstructed.non_tensor_batch["labels"], dataproto_with_non_tensor.non_tensor_batch["labels"] + ) + + def test_seqlen_balancing_distributed_params(tmp_path): world_size = 2 init_file = tmp_path / "dist_init" diff --git a/tests/utils/test_temp_env_on_cpu.py b/tests/utils/test_temp_env_on_cpu.py new file mode 100644 index 000000000..851e4cbe4 --- /dev/null +++ b/tests/utils/test_temp_env_on_cpu.py @@ -0,0 +1,143 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from verl.utils.py_functional import temp_env_var + + +@pytest.fixture(autouse=True) +def clean_env(): + """Fixture to clean up environment variables before and after each test.""" + # Store original environment state + original_env = dict(os.environ) + + # Clean up any test variables that might exist + test_vars = ["TEST_VAR", "TEST_VAR_2", "EXISTING_VAR"] + for var in test_vars: + if var in os.environ: + del os.environ[var] + + # Yield control to the test function + yield + + # Restore original environment state after test + os.environ.clear() + os.environ.update(original_env) + + +def test_set_new_env_var(): + """Test setting a new environment variable that didn't exist before.""" + # Ensure variable doesn't exist + assert "TEST_VAR" not in os.environ + + with temp_env_var("TEST_VAR", "test_value"): + # Variable should be set inside context + assert os.environ["TEST_VAR"] == "test_value" + assert "TEST_VAR" in os.environ + + # Variable should be removed after context + assert "TEST_VAR" not in os.environ + + +def test_restore_existing_env_var(): + """Test restoring an environment variable that already existed.""" + # Set up existing variable + os.environ["EXISTING_VAR"] = "original_value" + + with temp_env_var("EXISTING_VAR", "temporary_value"): + # Variable should be temporarily changed + assert os.environ["EXISTING_VAR"] == "temporary_value" + + # Variable should be restored to original value + assert os.environ["EXISTING_VAR"] == "original_value" + + +def test_env_var_restored_on_exception(): + """Test that environment variables are restored even when exceptions occur.""" + # Set up existing variable + os.environ["EXISTING_VAR"] = "original_value" + + with pytest.raises(ValueError): + with temp_env_var("EXISTING_VAR", "temporary_value"): + # Verify variable is set + assert os.environ["EXISTING_VAR"] == "temporary_value" + # Raise exception + raise ValueError("Test exception") + + # Variable should still be restored despite exception + assert os.environ["EXISTING_VAR"] == "original_value" + + +def test_nested_context_managers(): + """Test nested temp_env_var context managers.""" + # Set up original variable + os.environ["TEST_VAR"] = "original" + + with temp_env_var("TEST_VAR", "level1"): + assert os.environ["TEST_VAR"] == "level1" + + with temp_env_var("TEST_VAR", "level2"): + assert os.environ["TEST_VAR"] == "level2" + + # Should restore to level1 + assert os.environ["TEST_VAR"] == "level1" + + # Should restore to original + assert os.environ["TEST_VAR"] == "original" + + +def test_multiple_different_vars(): + """Test setting multiple different environment variables.""" + # Set up one existing variable + os.environ["EXISTING_VAR"] = "existing_value" + + with temp_env_var("EXISTING_VAR", "modified"): + with temp_env_var("TEST_VAR", "new_value"): + assert os.environ["EXISTING_VAR"] == "modified" + assert os.environ["TEST_VAR"] == "new_value" + + # Check restoration + assert os.environ["EXISTING_VAR"] == "existing_value" + assert "TEST_VAR" not in os.environ + + +def test_empty_string_value(): + """Test setting environment variable to empty string.""" + with temp_env_var("TEST_VAR", ""): + assert os.environ["TEST_VAR"] == "" + assert "TEST_VAR" in os.environ + + # Should be removed after context + assert "TEST_VAR" not in os.environ + + +def test_overwrite_with_empty_string(): + """Test overwriting existing variable with empty string.""" + os.environ["EXISTING_VAR"] = "original" + + with temp_env_var("EXISTING_VAR", ""): + assert os.environ["EXISTING_VAR"] == "" + + # Should restore original value + assert os.environ["EXISTING_VAR"] == "original" + + +def test_context_manager_returns_none(): + """Test that context manager yields None.""" + with temp_env_var("TEST_VAR", "value") as result: + assert result is None + assert os.environ["TEST_VAR"] == "value" diff --git a/tests/utils/cpu_tests/test_timeout_decorator.py b/tests/utils/test_timeout_decorator_cpu.py similarity index 100% rename from tests/utils/cpu_tests/test_timeout_decorator.py rename to tests/utils/test_timeout_decorator_cpu.py diff --git a/tests/utils/gpu_tests/test_torch_functional.py b/tests/utils/test_torch_functional.py similarity index 87% rename from tests/utils/gpu_tests/test_torch_functional.py rename to tests/utils/test_torch_functional.py index 5084e18ea..900cb5d54 100644 --- a/tests/utils/gpu_tests/test_torch_functional.py +++ b/tests/utils/test_torch_functional.py @@ -19,7 +19,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std +from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std, masked_mean def _worker_mean(rank: int, world_size: int, rendezvous_file: str): @@ -52,6 +52,20 @@ def _worker_mean(rank: int, world_size: int, rendezvous_file: str): dist.destroy_process_group() +@pytest.mark.parametrize( + "value,mask,gt", + [ + ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5), + ([1.0, 2.0, float("nan"), 4.0], [1, 0, 0, 1], 2.5), + ([1.0, 2.0, float("nan"), 4.0], [1, 0, 1, 0], float("nan")), + ], +) +def test_masked_mean(value, mask, gt): + res = masked_mean(torch.tensor(value), torch.tensor(mask)) + gt = torch.tensor(gt) + assert torch.allclose(res, gt) or (torch.isnan(res) and torch.isnan(gt)) + + @pytest.mark.parametrize("world_size", [2, 4]) def test_distributed_mean_max_min_std(world_size, tmp_path): rendezvous_file = str(tmp_path / "rdzv_mean") diff --git a/tests/workers/reward_manager/test_registry_on_cpu.py b/tests/workers/reward_manager/test_registry_on_cpu.py new file mode 100644 index 000000000..9932ae891 --- /dev/null +++ b/tests/workers/reward_manager/test_registry_on_cpu.py @@ -0,0 +1,94 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +# Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module +from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register + + +@pytest.fixture +def setup(): + """Setup test cases with a mock registry.""" + REWARD_MANAGER_REGISTRY.clear() + REWARD_MANAGER_REGISTRY.update({"manager1": "Manager1Class", "manager2": "Manager2Class"}) + return REWARD_MANAGER_REGISTRY + + +def test_get_existing_manager(setup): + """Test getting an existing reward manager class.""" + assert get_reward_manager_cls("manager1") == "Manager1Class" + assert get_reward_manager_cls("manager2") == "Manager2Class" + + +def test_get_nonexistent_manager(setup): + """Test getting a non-existent reward manager raises ValueError.""" + with pytest.raises(ValueError) as excinfo: + get_reward_manager_cls("unknown_manager") + assert "Unknown reward manager: unknown_manager" in str(excinfo.value) + + +def test_case_sensitivity(setup): + """Test that manager names are case-sensitive.""" + with pytest.raises(ValueError): + get_reward_manager_cls("MANAGER1") + with pytest.raises(ValueError): + get_reward_manager_cls("Manager1") + + +def test_empty_registry(setup): + """Test behavior when registry is empty.""" + REWARD_MANAGER_REGISTRY.clear() + with pytest.raises(ValueError) as excinfo: + get_reward_manager_cls("any_manager") + assert "Unknown reward manager: any_manager" in str(excinfo.value) + + +def test_register_new_class(setup): + """Test registering a new class with the decorator.""" + + @register("test_manager") + class TestManager: + pass + + assert "test_manager" in REWARD_MANAGER_REGISTRY + assert REWARD_MANAGER_REGISTRY["test_manager"] == TestManager + + +def test_register_different_classes_same_name(setup): + """Test that registering different classes with same name raises ValueError.""" + + @register("conflict_manager") + class Manager1: + pass + + with pytest.raises(ValueError): + + @register("conflict_manager") + class Manager2: + pass + + assert REWARD_MANAGER_REGISTRY["conflict_manager"] == Manager1 + + +def test_decorator_returns_original_class(setup): + """Test that the decorator returns the original class unchanged.""" + + @register("return_test") + class OriginalClass: + def method(setup): + return 42 + + assert OriginalClass().method() == 42 + assert REWARD_MANAGER_REGISTRY["return_test"] == OriginalClass diff --git a/tests/workers/rollout/async_rollout_utils.py b/tests/workers/rollout/async_rollout_utils.py index bc6186553..22f20291e 100644 --- a/tests/workers/rollout/async_rollout_utils.py +++ b/tests/workers/rollout/async_rollout_utils.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Any, Dict import ray from omegaconf import DictConfig @@ -24,12 +22,7 @@ from verl.workers.rollout.async_server import AsyncLLMServerManager -def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, Any] = None) -> AsyncLLMServerManager: - # make openai client happy - os.environ["no_proxy"] = "" - os.environ["http_proxy"] = "" - os.environ["https_proxy"] = "" - +def init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager: # =========================== 1. Create hybrid ActorRollout workers =========================== role_worker_mapping = { Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), @@ -47,7 +40,9 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A # create actor and rollout resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout") + actor_rollout_cls = RayClassWithInitArgs( + cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + ) resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls all_wg = {} @@ -61,9 +56,8 @@ def init_async_rollout_manager(config: DictConfig, scheduler_kwargs: Dict[str, A # =========================== 2. Create AsyncLLMServerManager =========================== async_rollout_manager = AsyncLLMServerManager( - config=config.actor_rollout_ref, + config=config, worker_group=actor_rollout_wg, - scheduler_kwargs=scheduler_kwargs, ) return async_rollout_manager diff --git a/tests/workers/rollout/perf/vllm_async_rollout.py b/tests/workers/rollout/perf/vllm_async_rollout.py new file mode 100644 index 000000000..dbcd255df --- /dev/null +++ b/tests/workers/rollout/perf/vllm_async_rollout.py @@ -0,0 +1,135 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Compare vLLM AsyncLLM backend: ExternalRayDistributedExecutor(remote call) vs RayDistributedExecutor(compiled graph) + +1. Prepare openai/gsm8k dataset +python3 examples/data_preprocess/gsm8k.py + +2. Run perf test +python3 tests/workers/rollout/perf/vllm_async_rollout.py >perf.log 2>&1 + +hardware: Nvidia 8*H20 +packages: +- torch==2.6.0 +- vllm==0.8.5 + +[DEBUG] backend: sync, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 21.27 secs +[DEBUG] backend: zeromq, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 23.40 secs +[DEBUG] backend: ray, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 25.33 secs +""" + +import os +import time + +import ray +from omegaconf import DictConfig +from torch.utils.data import SequentialSampler +from torchdata.stateful_dataloader import StatefulDataLoader + +from tests.experimental.agent_loop.agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer +from verl.utils.dataset import RLHFDataset +from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + +def init_config(n_gpus_per_node) -> DictConfig: + import os + + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + config.trainer.n_gpus_per_node = n_gpus_per_node + config.data.train_batch_size = 128 + config.data.return_raw_chat = True + config.actor_rollout_ref.model.path = "Qwen/Qwen2.5-7B-Instruct" + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 + config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9 + config.actor_rollout_ref.rollout.multi_turn.format = "hermes" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 16 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]: + env_vars = { + "NCCL_DEBUG": "WARN", + "VLLM_USE_V1": "1", + "VERL_VLLM_DISTRIBUTED_BACKEND": backend, + } + ray.init(runtime_env={"env_vars": env_vars}) + + # STEP 1: init async llm server + server = init_agent_loop_manager(config) + + # STEP 2: create dataloader + tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) + dataset = RLHFDataset( + data_files=os.path.expanduser("~/data/gsm8k/train.parquet"), + tokenizer=tokenizer, + config=config.data, + ) + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=config.data.get("gen_batch_size", config.data.train_batch_size), + num_workers=config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=default_collate_fn, + sampler=SequentialSampler(dataset), + ) + + return server, dataloader + + +def perf_rollout(mode, backend, n_gpus_per_node, num_steps): + config = init_config(n_gpus_per_node) + config.actor_rollout_ref.rollout.mode = mode + agent_loop_manager, dataloader = initialize(config, backend) + + for step, batch in enumerate(dataloader): + batch: DataProto = DataProto.from_single_dict(batch) + batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "raw_prompt"], + ) + t_start = time.time() + gen_batch = agent_loop_manager.generate_sequences(batch) + t_end = time.time() + print( + f"[DEBUG] backend: {backend}, n_gpus_per_node: {n_gpus_per_node}, batch_size: {len(gen_batch)}, " + f"step: {step}, step_time: {t_end - t_start:.2f} secs" + ) + if step + 1 >= num_steps: + break + + ray.shutdown() + + +if __name__ == "__main__": + num_steps = 1 + n_gpus_per_node = 8 + + # test_cases = [("sync", "sync"), ("async", "zeromq"), ("async", "ray")] + test_cases = [("async", "zeromq"), ("async", "ray")] + for mode, backend in test_cases: + perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps) diff --git a/tests/workers/rollout/resource/tool_configs/mcp_server.json b/tests/workers/rollout/resource/tool_configs/mcp_server.json new file mode 100644 index 000000000..9ed41f10b --- /dev/null +++ b/tests/workers/rollout/resource/tool_configs/mcp_server.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "Tavily Expert": { + "url": "https://tavily.api.tadata.com/mcp/tavily/your_expert", + "auth_token": "your_tavily_token" + } + } +} \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/mcp_tool_config b/tests/workers/rollout/resource/tool_configs/mcp_tool_config new file mode 100644 index 000000000..a9a45bd0b --- /dev/null +++ b/tests/workers/rollout/resource/tool_configs/mcp_tool_config @@ -0,0 +1,11 @@ +tools: + - class_name: verl.tools.mcp_search_tool.MCPSearchTool + config: + rate_limit: 120 + timeout: 120 + type: mcp + mcp: + mcp_servers_config_path: ./resource/tool_configs/mcp_server.json + # optional + tool_selected_list: + - tavily_search_tool \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config b/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config index dcfe6ef26..aa3f1eec5 100644 --- a/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config +++ b/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config @@ -2,6 +2,7 @@ tools: - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" config: sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" + type: native tool_schema: type: "function" function: diff --git a/tests/workers/rollout/resource/tool_configs/search_tool_config b/tests/workers/rollout/resource/tool_configs/search_tool_config index 79b647e62..926b6b832 100644 --- a/tests/workers/rollout/resource/tool_configs/search_tool_config +++ b/tests/workers/rollout/resource/tool_configs/search_tool_config @@ -5,6 +5,7 @@ tools: num_workers: 120 rate_limit: 120 timeout: 30 + type: native tool_schema: type: function function: diff --git a/tests/workers/rollout/run_fsdp_vllm.py b/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py similarity index 93% rename from tests/workers/rollout/run_fsdp_vllm.py rename to tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py index efd64e304..69223890d 100644 --- a/tests/workers/rollout/run_fsdp_vllm.py +++ b/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py @@ -100,11 +100,15 @@ def main(): device_mesh=device_mesh, ) - FSDP.set_state_dict_type(fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + ) state_dict = fsdp_model.state_dict() - sampling_params = SamplingParams(temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False) + sampling_params = SamplingParams( + temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False + ) print(actor_model_config) llm = LLM( @@ -142,7 +146,7 @@ def main(): batch_size = input_ids.shape[0] pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs + from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import _pre_process_inputs for i in range(batch_size): idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py b/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py new file mode 100644 index 000000000..93aca6a2d --- /dev/null +++ b/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py @@ -0,0 +1,242 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig +from transformers.utils import get_json_schema + +from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager +from verl.protocol import DataProto +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + import os + + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.multi_turn.format = "hermes" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +def test_vllm_async_rollout_without_tool_calls(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + async_rollout_manager = init_async_rollout_manager(init_config) + + # test sleep and wake_up + async_rollout_manager.sleep() + async_rollout_manager.wake_up() + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", + } + ], + [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + }, + ) + result = async_rollout_manager.generate_sequences(prompts=batch) + + # check result + seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) + assert len(result) == 2 + assert result.batch["input_ids"].size(1) == seq_len + assert result.batch["attention_mask"].size(1) == seq_len + assert result.batch["position_ids"].size(1) == seq_len + + # check turns + num_turns = result.non_tensor_batch["__num_turns__"] + assert np.all(num_turns == 2) + + print("Test passed!") + ray.shutdown() + + +class WeatherTool(BaseTool): + def get_current_temperature(self, location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_current_temperature) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_current_temperature(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +class WeatherToolWithData(BaseTool): + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.get_temperature_date) + return OpenAIFunctionToolSchema(**schema) + + def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + try: + result = self.get_temperature_date(**parameters) + return json.dumps(result), 0, {} + except Exception as e: + return str(e), 0, {} + + +def test_vllm_async_rollout_with_tool_calls(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + tool_config = { + "tools": [ + { + "class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherTool", + "config": {"type": "native"}, + }, + { + "class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherToolWithData", + "config": {"type": "native"}, + }, + ] + } + tool_config_path = "/tmp/tool_config.json" + with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + + init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + async_rollout_manager = init_async_rollout_manager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + }, + ) + result = async_rollout_manager.generate_sequences(prompts=batch) + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + # [user, assistant] + assert num_turns[0] == 2 + # [user, assistant, tool, assistant] + assert num_turns[1] == 4 + # [system, user, assistant, tool, tool, assistant] + assert num_turns[2] == 6 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + + # Decode responses with response_mask + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_str = tokenizer.decode(valid_tokens) + assert "" not in response_str, f"found in response: {response_str}" + assert "" not in response_str, f"found in response: {response_str}" + print(f"response: {response_str}") + + print("Test passed!") + ray.shutdown() diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py b/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py new file mode 100644 index 000000000..30c9ae2bc --- /dev/null +++ b/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py @@ -0,0 +1,131 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import torch +import torch.distributed +import torch.distributed as dist +from omegaconf import OmegaConf +from transformers import AutoConfig, AutoTokenizer + +from verl import DataProto +from verl.utils.distributed import initialize_global_process_group +from verl.utils.model import compute_position_id_with_mask +from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout + + +def test_vllm_rollout_with_yarn_position_embeddings(): + """ + Test the vLLM rollout with yarn position embeddings. + """ + + local_rank, rank, world_size = initialize_global_process_group() + config = OmegaConf.create( + { + "model_path": "OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN", + "prompt_length": 35000, + "response_length": 512, + "dtype": "bfloat16", + "enforce_eager": True, + "gpu_memory_utilization": 0.4, + "enable_chunked_prefill": False, + "free_cache_engine": False, + "disable_log_stats": True, + "max_model_len": 35000 + 512, + "load_format": "auto", + "val_kwargs": { + "top_k": -1, + "top_p": 1.0, + "temperature": 0, + "n": 1, + "do_sample": False, + }, + "tensor_model_parallel_size": 4, + "trust_remote_code": True, + "calculate_log_probs": False, + "do_sample": False, + "temperature": 0.0, + "max_num_batched_tokens": 35000 + 512, + } + ) + + tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + model_hf_config = AutoConfig.from_pretrained(config.model_path) + + # do_sample=False for temperate=0 deterministic + input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False) + + vllm_rollout = vLLMRollout( + model_path=config.model_path, + config=config, + tokenizer=tokenizer, + model_hf_config=model_hf_config, + ) + # rollout + rollout_response = vllm_rollout.generate_sequences( + prompts=input_dataproto, + ) + if rank == 0: + print("VLLM Rollout Outputs:") + print(tokenizer.batch_decode(rollout_response.batch["responses"][:], skip_special_tokens=False)) + for response in rollout_response.batch["responses"]: + assert "<|im_end|>" in tokenizer.decode(response, skip_special_tokens=False), ( + "Response should contain <|im_end|> token" + ) + print("Checks passed.") + + del vllm_rollout + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + dist.barrier() + torch.distributed.destroy_process_group() + + +def prepare_input_dataproto(tokenizer, config, validate, do_sample=False): + base_phrase = "Roses are red, sky is blue. " * 4096 + preencode_prompts = [ + # 32810 tokens > 32768 tokens + [{"role": "user", "content": base_phrase + "Who won the Champions League in 2019?"}], + [{"role": "user", "content": base_phrase + "The founder of Apple is"}], + [{"role": "user", "content": base_phrase + "What's your name"}], + ] + formatted_prompts = [ + tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + for conversation in preencode_prompts + ] + prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) + input_dataproto = DataProto.from_dict( + { + "input_ids": prompts["input_ids"], + "attention_mask": prompts["attention_mask"], + "position_ids": compute_position_id_with_mask(prompts["attention_mask"]), + }, + meta_info={ + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": tokenizer.pad_token_id, + "validate": validate, + "do_sample": do_sample, + "response_length": config.response_length, + "temperature": config.temperature, + }, + ) + return input_dataproto + + +if __name__ == "__main__": + test_vllm_rollout_with_yarn_position_embeddings() diff --git a/tests/workers/rollout/test_vllm_spmd.py b/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py similarity index 91% rename from tests/workers/rollout/test_vllm_spmd.py rename to tests/workers/rollout/rollout_vllm/test_vllm_spmd.py index 5b9607875..c2b8f51cb 100644 --- a/tests/workers/rollout/test_vllm_spmd.py +++ b/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py @@ -14,6 +14,7 @@ import os +import pytest import torch from torch.distributed.fsdp import CPUOffload, MixedPrecision from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -54,7 +55,7 @@ def are_lists_similar(a, b): total_length = 0 total_diff = 0 - for s1, s2 in zip(a, b): + for s1, s2 in zip(a, b, strict=True): max_len = max(len(s1), len(s2)) total_length += max_len diff = levenshtein(s1, s2) @@ -67,6 +68,7 @@ def are_lists_similar(a, b): return percentage_difference <= 15 +@pytest.mark.skip("https://github.com/vllm-project/vllm/issues/16993") def test_vllm_spmd(): assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." local_rank, rank, world_size = initialize_global_process_group() @@ -105,7 +107,9 @@ def test_vllm_spmd(): temperature = 0 top_p = 1 - kwargs = dict(n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True) + kwargs = dict( + n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True + ) tensor_parallel_size = 4 @@ -127,7 +131,9 @@ def test_vllm_spmd(): device_mesh=device_mesh, ) - FSDP.set_state_dict_type(fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + ) state_dict = fsdp_model.state_dict() @@ -141,7 +147,6 @@ def test_vllm_spmd(): enforce_eager=True, gpu_memory_utilization=0.8, disable_custom_all_reduce=True, - disable_mm_preprocessor_cache=True, skip_tokenizer_init=False, enable_prefix_caching=True, trust_remote_code=True, @@ -156,7 +161,9 @@ def test_vllm_spmd(): world_size = torch.distributed.get_world_size() model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model - model.load_weights(((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items())) + model.load_weights( + ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items()) + ) outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) verl_vllm_response_tokens = [] diff --git a/tests/workers/rollout/test_async_sglang_server.py b/tests/workers/rollout/test_async_sglang_server.py index 914f527c9..0b4e914f1 100644 --- a/tests/workers/rollout/test_async_sglang_server.py +++ b/tests/workers/rollout/test_async_sglang_server.py @@ -25,14 +25,6 @@ }, ) class TestAsyncSglangServer: - @pytest.fixture - def mock_ray_actor(self): - mock_actor = MagicMock() - mock_actor.execute_method.remote = AsyncMock(return_value={"content": "mocked response"}) - mock_actor.resume.remote = AsyncMock() - mock_actor.offload.remote = AsyncMock() - return mock_actor - @pytest.fixture def server_config(self): return DictConfig({"rollout": {"tensor_model_parallel_size": 2}}) @@ -41,22 +33,72 @@ def server_config(self): @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors") @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock) @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning") - async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config, mock_ray_actor): + async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config): mock_list_actors.return_value = [ + {"name": "test_prefixWorkerDict_1:0", "namespace": "test"}, + {"name": "test_prefixWorkerDict_1:1", "namespace": "test"}, {"name": "test_prefixWorkerDict_0:0", "namespace": "test"}, {"name": "test_prefixWorkerDict_0:1", "namespace": "test"}, {"name": "test_prefixWorkerDict_1:2", "namespace": "test"}, + {"name": "test_prefixWorkerDict_1:3", "namespace": "test"}, + {"name": "test_prefixWorkerDict_0:2", "namespace": "test"}, + {"name": "test_prefixWorkerDict_0:3", "namespace": "test"}, ] from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer ActualClassToInstantiate = AsyncSglangServer - if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"): + if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr( + AsyncSglangServer.__ray_metadata__, "modified_class" + ): ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class - with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", return_value=mock_ray_actor): - instance = ActualClassToInstantiate(server_config, 2, 0, "test_prefix") + def mock_get_actor_side_effect(name, namespace=None): + # Create a new mock actor for each call + actor_mock = MagicMock() + + # Support .name attribute access + actor_mock.name = name # Use 'name' here + + # Support ['name'] item access by mocking __getitem__ + def getitem_mock(key): + if key == "name": + return name # Use 'name' here + # For other keys, return a new MagicMock to mimic default behavior or raise KeyError + # Returning a MagicMock is consistent with the original error's cause for unmocked keys + return MagicMock(name=f"mock.__getitem__('{key}')") + + actor_mock.__getitem__.side_effect = getitem_mock + + return actor_mock + + # Verify instance.workers is correctly populated + with patch( + "verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", + side_effect=mock_get_actor_side_effect, + ): + # Instance 1 + instance = ActualClassToInstantiate(server_config, 4, 0, "test_prefix") + await instance.init_engine() + + assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_0:0" + assert instance.workers[0].name == "test_prefixWorkerDict_0:0" + assert instance.workers[1].name == "test_prefixWorkerDict_0:1" + + # Instance 2 + instance = ActualClassToInstantiate(server_config, 4, 1, "test_prefix") + await instance.init_engine() + + assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_0:2" + assert instance.workers[0].name == "test_prefixWorkerDict_0:2" + assert instance.workers[1].name == "test_prefixWorkerDict_0:3" + # Instance 3 + instance = ActualClassToInstantiate(server_config, 4, 3, "test_prefix") await instance.init_engine() - # Verify instance.workers is correctly populated assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_1:2" + assert instance.workers[0].name == "test_prefixWorkerDict_1:2" + assert instance.workers[1].name == "test_prefixWorkerDict_1:3" diff --git a/tests/workers/rollout/test_custom_completion_callback.py b/tests/workers/rollout/test_custom_completion_callback.py new file mode 100644 index 000000000..c17d5272c --- /dev/null +++ b/tests/workers/rollout/test_custom_completion_callback.py @@ -0,0 +1,305 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import concurrent.futures +import os +import re +import socket +import sys +import tempfile +from contextlib import asynccontextmanager +from typing import Any + +import fastapi +import numpy as np +import ray +import uvicorn +from datasets import load_dataset +from omegaconf import DictConfig +from openai.types.chat.chat_completion import ChatCompletion +from starlette.requests import Request +from starlette.responses import JSONResponse + +from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer +from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback + + +def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +@ray.remote(num_cpus=1) +class Sandbox: + """Sandbox to execute python code. + + WARNING: This class is for testing purpose only, do not use it in production. + Please use a sandbox with strong isolation and security restrictions instead. + """ + + def __init__(self): + self.address = ray.util.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def code_execution(self, request: Request): + request_json = await request.json() + code = request_json["code"] + print(f"execute code:\n{code}") + + _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) + with open(temp_file, "w") as f: + f.write(code) + + try: + process = await asyncio.create_subprocess_exec( + sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + + response = { + "status": "Success" if process.returncode == 0 else "Failed", + "run_result": { + "status": "Finished", + "stdout": stdout.decode(), + "stderr": stderr.decode(), + "return_code": process.returncode, + }, + } + return JSONResponse(content=response) + finally: + try: + os.unlink(temp_file) + except: # noqa: E722 + pass + + async def _start_fastapi_server(self): + @asynccontextmanager + async def lifespan(app: fastapi.FastAPI): + print("FastAPI startup") + self.server_ready.set() + yield + + print("FastAPI shutdown, maybe address already in use, exit process immediately.") + os._exit(-1) + + app = fastapi.FastAPI(lifespan=lifespan) + app.router.add_api_route("/run_code", self.code_execution, methods=["POST"]) + + self.port = _get_free_port() + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> str: + """Get FastAPI server address.""" + await self.server_ready.wait() + return f"{self.address}:{self.port}" + + +class CustomCompletionCallback(ToolCompletionCallback): + def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler): + super().__init__(config, scheduler) + + self.max_assistant_turns = 16 + self.answer_pattern = re.compile(r"(.*?)", re.DOTALL) + self.code_pattern = re.compile(r"\s*```python(.*?)```\s*", re.DOTALL) + + self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url + self.default_timeout = 10 + self.memory_limit_mb = config.reward_model.sandbox_fusion.memory_limit_mb + # TODO: support asyncio executor + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) + + async def sandbox_code_execution(self, code: str) -> dict[str, Any]: + loop = asyncio.get_running_loop() + result_status, metadata = await loop.run_in_executor( + self.executor, + _process_single_case, + 0, # case_index, + None, # stdin_data, + None, # expected_output, + self.sandbox_fusion_url, # sandbox_fusion_url + code, # generation + self.default_timeout, # timeout + self.memory_limit_mb, # memory limit + "python", # language + ) + + return metadata + + @property + def extra_body(self): + extra = { + "include_stop_str_in_output": True, + "stop": ["
", ""], + } + return extra + + async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): + role, content, finish_reason = ( + completions.choices[0].message.role, + completions.choices[0].message.content, + completions.choices[0].finish_reason, + ) + messages.append({"role": role, "content": content}) + turn = len(messages) + + # STEP 0: check if we reach max turns + if len(messages) >= self.max_assistant_turns: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!") + return + + # STEP 1: check if we reach max tokens + if finish_reason == "length": + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!") + return + + # STEP 2: check if we got answer + matches = self.answer_pattern.findall(content) + if matches: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!") + return + + # STEP 3: check if we got code block + matches = self.code_pattern.findall(content) + if not matches: + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!") + return + + # STEP 4: execute code block in sandbox + code = matches[0].strip() + metadata = await self.sandbox_code_execution(code) + if metadata["run_status"] != "Finished": + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block execution failed: " + f"{metadata}, done!" + ) + return + + stdout, stderr = metadata["stdout"], metadata["stderr"] + messages.append({"role": "tool", "content": f"{stdout}{stderr}"}) + print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...") + + # STEP 5: resubmit chat completions with code block output + self.scheduler.submit_chat_completions( + messages=messages, + request_id=completions.id, + info=info, + ) + + +user_prompt_template = """ +You are a helpful assistant. Let's solve math problem in following steps: +1. Write a python code first and return the code to user, the code must be in following format: + + +```python +import os + +print(...) +``` + + +The code must explictly print necessary output to stdout. Remember stop generation at immediately and +return the code. +2. User will send the python code to a external sandbox to execute and get output from stdout. +3. User will send the output in format output to you, and you should use the +output to answer the question. +The answer format must be: \\boxed{'The final answer goes here.'} + +*user question:* +{question} +""" + + +if __name__ == "__main__": + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # Load config + import os + + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.multi_turn.format = "hermes" + config.actor_rollout_ref.rollout.multi_turn.completion_callback = ( + "tests.workers.rollout.test_custom_completion_callback.CustomCompletionCallback" + ) + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + + # Init sandbox and async rollout manager + sandbox = Sandbox.options(num_cpus=1).remote() + sandbox_address = ray.get(sandbox.get_server_address.remote()) + sandbox_fusion_url = f"http://{sandbox_address}/run_code" + config.reward_model.sandbox_fusion.url = sandbox_fusion_url + async_rollout_manager = init_async_rollout_manager(config) + + # Build dataset + dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train") + prompts = DataProto( + non_tensor_batch={ + "raw_prompt": np.array( + [ + [{"role": "user", "content": user_prompt_template.replace("{question}", problem)}] + for problem in dataset["Problem"] + ] + ), + }, + ) + + result = async_rollout_manager.generate_sequences(prompts=prompts) + assert len(result) == len(dataset) * config.actor_rollout_ref.rollout.n + + # Check max turns that sandbox is called + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + assert np.max(num_turns) > 2, f"max turns: {np.max(num_turns)}" + + # Check response_mask + tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + + # Decode responses with response_mask + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response_str = tokenizer.decode(valid_tokens) + assert "" not in response_str, f"found in response: {response_str}" + assert "" not in response_str, f"found in response: {response_str}" + print(f"response: {response_str}") + + print("Test passed!") diff --git a/tests/workers/rollout/test_hf_rollout.py b/tests/workers/rollout/test_hf_rollout.py index bde643abf..3eb6f4bb2 100644 --- a/tests/workers/rollout/test_hf_rollout.py +++ b/tests/workers/rollout/test_hf_rollout.py @@ -51,7 +51,10 @@ def prepare_input_dataproto(tokenizer, config, validate): [{"role": "user", "content": "The founder of Apple is"}], [{"role": "user", "content": "What's your name"}], ] - formatted_prompts = [tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) for conversation in preencode_prompts] + formatted_prompts = [ + tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + for conversation in preencode_prompts + ] prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) input_dataproto = DataProto.from_dict( { @@ -88,7 +91,9 @@ def prepare_fsdp_model(model, world_size): device_mesh=device_mesh, ) - FSDP.set_state_dict_type(fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()) + FSDP.set_state_dict_type( + fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + ) return fsdp_model @@ -147,7 +152,9 @@ def test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False): first_eos_pos = eos_positions[0].item() assert response_attention[: first_eos_pos + 1].all(), "Response attention mask should be 1 until EOS" if first_eos_pos + 1 < response_length: - assert not response_attention[first_eos_pos + 1 :].any(), "Response attention mask should be 0 after EOS" + assert not response_attention[first_eos_pos + 1 :].any(), ( + "Response attention mask should be 0 after EOS" + ) else: assert response_attention.all(), "Response attention mask should be all 1 if no EOS token" diff --git a/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py b/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py new file mode 100644 index 000000000..387de1618 --- /dev/null +++ b/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py @@ -0,0 +1,461 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py + + +import asyncio +from copy import deepcopy +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from tensordict import TensorDict +from transformers import AutoConfig, AutoTokenizer +from utils_sglang import get_rollout_config, prepare_inputs + +from verl.protocol import DataProto +from verl.tools.mcp_search_tool import MCPSearchTool +from verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager +from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout + +DEFAULT_USER_CONTENT_PREFIX = ( + "Answer the given question. You must conduct reasoning inside and " + "first every time you get new information. After reasoning, if you find you lack " + "some knowledge, you can call a search engine by query " + "and it will return the top searched results between and " + ". You can search as many times as your want. If you find no " + "further external knowledge needed, you can directly provide the answer inside " + " and , without detailed illustrations. For example, " + " Beijing . Question: " +) +user_content = DEFAULT_USER_CONTENT_PREFIX.rstrip("\n") + "How's the weather lately?" + + +def get_search_messages(): + user_prompt = { + "role": "user", + "content": user_content, + } + + expect_turn_0_msg = { + "role": "assistant", + "content": "Let me search the web.", + "tool_calls": [ + { + "id": "10", + "type": "function", + "function": { + "name": "tavily_search_tool", + "arguments": { + "what_is_your_intent": "Search for the weather lately", + "query": "the weather in Beijing today", + "search_depth": "basic", + "time_range": "day", + "include_domains": ["google.com", "baidu.com"], + "max_results": 2, + }, + }, + } + ], + } + + expect_turn_1_msg = { + "role": "assistant", + "content": "Let me search again.", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "tavily_search_tool", + "arguments": { + "what_is_your_intent": "Search for the weather lately", + "query": "the weather in Beijing tomorrow", + "search_depth": "basic", + "time_range": "day", + "include_domains": ["google.com", "baidu.com"], + "max_results": 2, + }, + }, + } + ], + } + + expect_turn_2_msg = { + "role": "assistant", + "content": "Today is sunny and tomorrow will be cloudy in Beijing.", + } + + # Mock search tool responses + tool_return_0_msg = {"role": "tool", "content": [{"type": "text", "text": "Today's weather in Beijing is sunny."}]} + tool_return_1_msg = { + "role": "tool", + "content": [{"type": "text", "text": "Tomorrow's weather in Beijing is cloudy."}], + } + + user_prompts = [user_prompt] + expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] + tool_return_array = [tool_return_0_msg, tool_return_1_msg] + + return user_prompts, expect_turn_array, tool_return_array + + +class TestRolloutWithMCPSearchTools: + @pytest.fixture + def qwen_tokenizer(self): + local_model_path = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + # we only need this for tokenizer + @pytest.fixture + def qwen_model_config(self): + local_model_path = "Qwen/Qwen2.5-0.5B" + config = AutoConfig.from_pretrained(local_model_path) + return config + + @pytest.fixture + def search_data(self, qwen_tokenizer): + user_prompt, expect_turn_array, tool_return_array = get_search_messages() + prompts = [[message] for message in user_prompt] + preencode_turn_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + for turn in expect_turn_array + ] + preencode_tool_return_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + for turn in tool_return_array + ] + return prompts, preencode_turn_array, preencode_tool_return_array + + @pytest.fixture + def search_rollout_config(self): + max_prompt_length = 4096 + max_response_length = 3000 + dtype = "bfloat16" + tensor_parallel_size = 1 + tool_path = "./resource/tool_configs/mcp_tool_config" + rollout_config = get_rollout_config( + max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + ) + return rollout_config + + @pytest.fixture + def search_data_proto(self, search_data, qwen_tokenizer): + preencode_prompts, _, _ = search_data + prompts = [ + qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in preencode_prompts + ] + input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) + prompt_dict = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=input_ids.shape[0], + ) + messages = np.asarray(preencode_prompts) + + tools_kwargs = np.array( + [ + { + "tavily_search_tool": { + "create_kwargs": {"ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing."}, + }, + } + ], + dtype=object, + ) + index = np.array([0], dtype=object) + prompts = DataProto( + batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + ) + return prompts + + @pytest.fixture + def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config): + """Mock the rollout instance with sampling_params initialized.""" + tool_schema = [ + { + "type": "function", + "function": { + "name": "tavily_search_tool", + "description": "A powerful web search tool...", + "parameters": { + "type": "object", + "properties": { + "what_is_your_intent": { + "type": "string", + "description": "Describe your intent for using Tavily", + }, + "query": {"type": "string", "description": "Search query"}, + "search_depth": { + "type": "string", + "description": "The depth of the search ('basic' or 'advanced')", + }, + "topic": { + "type": "string", + "description": "The category of the search ('general' or 'news')", + }, + "days": { + "type": "integer", + "description": "Number of days back to include in search results (only for " + "'news' topic)", + }, + "time_range": { + "type": "string", + "description": "Time range for results ('day', 'week', 'month', 'year', 'd', " + "'w', 'm', 'y')", + }, + "include_domains": { + "type": "array", + "description": "List of domains to specifically include in search results", + }, + "exclude_domains": { + "type": "array", + "description": "List of domains to specifically exclude from search results", + }, + "include_answer": { + "type": "boolean", + "description": "Whether to include an answer summary generated by an LLM", + }, + "include_raw_content": { + "type": "boolean", + "description": "Whether to include the cleaned and parsed HTML content of each result", + }, + "include_images": { + "type": "boolean", + "description": "Whether to include images from search results", + }, + "include_image_descriptions": { + "type": "boolean", + "description": "Whether to include descriptions with images", + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results to return (5-20)", + }, + "async_search": { + "type": "boolean", + "description": "Whether to perform the search asynchronously", + }, + }, + "required": ["what_is_your_intent", "query"], + }, + "strict": False, + }, + } + ] + with ( + patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema), + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + rollout = SGLangRollout( + actor_module="", + config=search_rollout_config, + processing_class=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) + rollout.sampling_params = { + "n": 1, + "max_new_tokens": search_rollout_config.response_length, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + } + return rollout + + def test_tools_registration(self, mock_rollout): + assert len(mock_rollout._tool_schemas) != 0 + assert "tavily_search_tool" in mock_rollout._tool_map.keys() + from verl.tools.mcp_search_tool import MCPSearchTool + + assert isinstance(mock_rollout._tool_map["tavily_search_tool"], MCPSearchTool) + # depend on the tokenizer + assert mock_rollout._tool_call_parser_type == "qwen25" + + def test_rollout_req_creation(self, mock_rollout, search_data_proto): + req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) + assert len(req_list) == 1 + assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING + assert len(req_list[0].tool_schemas) == 1 + + def test_over_size_case(self, mock_rollout, search_data_proto, search_data): + mock_rollout.config.multi_turn.max_assistant_turns = 1 + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = MagicMock(wraps=req, spec=AsyncRolloutRequest) + req.finalize = MagicMock() + req_list = [req] + + _, expect_turn_array, _ = search_data + # here we mock a meta info with 'length'. indicate the response is truncate + mock_rollout._handle_engine_call = MagicMock() + future = asyncio.Future() + future.set_result( + { + "text": expect_turn_array[0], + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "length", "length": 3000}, + "prompt_tokens": 132, + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 2.23543, + }, + } + ) + mock_rollout._handle_engine_call.return_value = future + mock_rollout._tp_rank = 0 + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete( + asyncio.gather( + *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + ) + ) + assert len(output_req_list) == 1 + output_req = output_req_list[0] + assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED + assert output_req.reward_scores.get("tavily_search_tool") == [] + # we should only have two message, one for prompt, second for response. + assert len(output_req.messages) == 2 + assert output_req.messages[1] == Message( + role="assistant", + content=expect_turn_array[0], + tool_calls=None, + ) + + @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) + def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + _, expect_turn_array, tool_return_array = search_data + # Mock search tool execution to return predefined responses + mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] + + mock_rollout.config.multi_turn.max_assistant_turns = 10 + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = MagicMock(wraps=req, spec=AsyncRolloutRequest) + req.finalize = MagicMock() + req_list = [req] + + mock_rollout._handle_engine_call = MagicMock() + futures = [asyncio.Future() for i in expect_turn_array] + for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): + i.set_result( + { + "text": turn, + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 2.23543, + }, + } + ) + if idx < len(expect_turn_array) - 1: + assert mock_rollout._function_call_parser.has_tool_call(turn) + assert mock_rollout._function_call_parser.parse_non_stream(turn) + + mock_rollout._handle_engine_call.side_effect = futures + mock_rollout._tp_rank = 0 + + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete( + asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) + ) + + # Verify conversation completed successfully with proper tool usage + output_req = output_req_list[0] + assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED + assert "tavily_search_tool" in output_req.metrics + assert output_req.metrics["tavily_search_tool"][0]["status"] == "success" + assert mock_execute.await_count == 2 + assert len(output_req.messages) == 6 + # Verify tool response messages contain expected content + search_counter = 0 + for msg in output_req.messages: + if msg.role == "tool": + assert msg.content == tool_return_array[search_counter] + search_counter += 1 + assert search_counter == 2 + + @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) + def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + _, expect_turn_array, tool_return_array = search_data + # Mock tool execution for large batch (100 requests * 2 calls each) + mock_execute.side_effect = [ + (tool_return_array[0], 0.0, {"status": "success"}), + (tool_return_array[1], 0.0, {"status": "success"}), + ] * 100 + + mock_rollout.config.multi_turn.max_assistant_turns = 10 + base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + + req_nums = 100 + req_list = [] + req_turns_map = {} + req_turns_counter = {} + + for i in range(req_nums): + tmp_req = deepcopy(base_req) + tmp_req.batch_data_id = i + tmp_req.request_id = i + req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) + + futures = [asyncio.Future() for _ in expect_turn_array] + for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): + fut.set_result( + { + "text": turn, + "meta_info": { + "id": "dummy", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + }, + } + ) + req_turns_map[i] = futures + req_turns_counter[i] = 0 + + async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): + fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] + req_turns_counter[_req.batch_data_id] += 1 + return await fut + + with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + mock_rollout._tp_rank = 0 + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete( + asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) + ) + + # Verify all requests completed successfully + assert len(output_req_list) == req_nums + for out_req in output_req_list: + assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED + assert "tavily_search_tool" in out_req.metrics + for metric in out_req.metrics["tavily_search_tool"]: + assert metric["status"] == "success" + assert len(out_req.messages) == 6 + assert sum(1 for m in out_req.messages if m.role == "tool") == 2 + + assert mock_execute.await_count == 2 * req_nums diff --git a/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py b/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py new file mode 100644 index 000000000..47fefca8a --- /dev/null +++ b/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py @@ -0,0 +1,187 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from verl.utils.dataset.vision_utils import process_image +from verl.utils.tokenizer import hf_processor +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + TokenizationSanityCheckModeEnum, +) + + +def _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False): + assert len(image_list) == len(description_list) + # Get the smallest dimensions across all images + processed_images = [] + for img_url in image_list: + img = process_image(img_url) + processed_images.append(img) + + min_width = min(img.size[0] for img in processed_images) + min_height = min(img.size[1] for img in processed_images) + min_size = (min_width, min_height) + + if resize_image: + processed_images_resized = [] + for img in processed_images: + img = img.resize(min_size) + processed_images_resized.append(img) + processed_images = processed_images_resized + + # Initial message history + system_prompt = ( + "You will be provided with an image. Describe this image and then generate a new image for the next round" + ) + messages = [ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here is the first image provided: "}, + {"type": "image", "image": [processed_images[0]]}, + ], + }, + ] + + # Initial multi_modal_data with one image + multi_modal_data = {"image": [processed_images[0]], "video": []} + # Minimal required fields for AsyncRolloutRequest + + req = AsyncRolloutRequest( + batch_data_id=0, + request_id="test-req-1", + state=AsyncRolloutRequestStateEnum.PENDING, + messages=messages, + multi_modal_keys=["image", "video"], + multi_modal_data=multi_modal_data.copy(), + tool_schemas=[], + tools_kwargs={}, + interaction_kwargs={}, + input_ids=None, + prompt_ids=None, + response_ids=None, + attention_mask=None, + prompt_attention_mask=None, + response_attention_mask=None, + position_ids=None, + prompt_position_ids=None, + response_position_ids=None, + loss_mask=None, + prompt_loss_mask=None, + response_loss_mask=None, + reward_scores={}, + max_prompt_len=8192, + max_response_len=8192, + max_model_len=16384, + metrics={}, + use_inference_chat_template=True, + tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT, + generation_prompt_ids=None, + base_conv_wo_gen_prompt_end_pos=0, + base_conv_with_gen_prompt_end_pos=0, + processing_class=processor, + ) + + prev_generated_len = 0 + # Add First Assistant Message and first tool response message(image) + for idx, img in enumerate(processed_images): + if idx == 0: + continue + _ = req.get_generation_prompt_ids(processor) + req.add_assistant_message(processor, content=description_list[idx - 1]) + before_tool_call_len = req.input_ids.shape[-1] + req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}]) + after_tool_call_len = req.input_ids.shape[-1] + if prev_generated_len == 0: + prev_generated_len = after_tool_call_len - before_tool_call_len + else: + if resize_image: + assert after_tool_call_len - before_tool_call_len == prev_generated_len + assert req.multi_modal_data["image"] == processed_images[: idx + 1] + + _ = req.get_generation_prompt_ids(processor) + req.add_assistant_message(processor, content=description_list[-1]) + + messages = [msg.model_dump() for msg in req.messages] + tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None + full_prompt_info = req._handle_apply_chat_template( + processor, + messages, + multi_modal_data=req.multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + full_prompt_ids = full_prompt_info["input_ids"] + assert full_prompt_ids.eq(req.input_ids).all() + + # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict + # because np.array() only keeps the keys for BatchFeature. + full_prompt_multi_modal_inputs = full_prompt_info.copy() + full_prompt_multi_modal_inputs.pop("input_ids", None) + full_prompt_multi_modal_inputs.pop("attention_mask", None) + + for key in full_prompt_multi_modal_inputs: + assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all() + + +@pytest.mark.skipif( + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" +) +def test_add_tool_response_messages_image_delta(): + processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") + + # From Qwen2.5-VL-3B-Instruct HF example + img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." + # GitHub Logo + img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_description = "A GitHub Logo image" + # Octocat + img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} + img_3_description = "An Octocat image" + + image_list = [img_1_url, img_2_url, img_3_url] + description_list = [img_1_description, img_2_description, img_3_description] + _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False) + + +@pytest.mark.skipif( + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" +) +def test_add_tool_response_messages_image_delta_resize_image(): + processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") + + # From Qwen2.5-VL-3B-Instruct HF example + img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." + # GitHub Logo + img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_description = "A GitHub Logo image" + # Octocat + img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} + img_3_description = "An Octocat image" + + image_list = [img_1_url, img_2_url, img_3_url] + description_list = [img_1_description, img_2_description, img_3_description] + _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True) diff --git a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py index 667d49271..2400d5c78 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py @@ -23,13 +23,15 @@ import pytest from tensordict import TensorDict from transformers import AutoConfig, AutoTokenizer -from utils_sglang import ( - get_rollout_config, - prepare_inputs, -) +from utils_sglang import get_rollout_config, prepare_inputs from verl.protocol import DataProto -from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema +from verl.tools.schemas import ( + OpenAIFunctionParametersSchema, + OpenAIFunctionPropertySchema, + OpenAIFunctionSchema, + OpenAIFunctionToolSchema, +) from verl.tools.search_tool import SearchTool from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout @@ -62,7 +64,9 @@ def get_search_messages(): expect_turn_1_msg = { "role": "assistant", "content": "Let me search again.", - "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "tomorrow's weather"}}}], + "tool_calls": [ + {"type": "function", "function": {"name": "search", "arguments": {"query": "tomorrow's weather"}}} + ], } expect_turn_2_msg = { @@ -100,8 +104,14 @@ def qwen_model_config(self): def search_data(self, qwen_tokenizer): user_prompt, expect_turn_array, tool_return_array = get_search_messages() prompts = [[message] for message in user_prompt] - preencode_turn_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) for turn in expect_turn_array] - preencode_tool_return_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) for turn in tool_return_array] + preencode_turn_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + for turn in expect_turn_array + ] + preencode_tool_return_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + for turn in tool_return_array + ] return prompts, preencode_turn_array, preencode_tool_return_array @pytest.fixture @@ -111,13 +121,18 @@ def search_rollout_config(self): dtype = "bfloat16" tensor_parallel_size = 1 tool_path = "./resource/tool_configs/search_tool_config" - rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path) + rollout_config = get_rollout_config( + max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + ) return rollout_config @pytest.fixture def search_data_proto(self, search_data, qwen_tokenizer): preencode_prompts, _, _ = search_data - prompts = [qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] + prompts = [ + qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in preencode_prompts + ] input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) prompt_dict = TensorDict( { @@ -133,21 +148,56 @@ def search_data_proto(self, search_data, qwen_tokenizer): [ { "search": { - "create_kwargs": {"ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing.", "data_source": "searchR1_nq"}, + "create_kwargs": { + "ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing.", + "data_source": "searchR1_nq", + }, }, } ], dtype=object, ) index = np.array([0], dtype=object) - prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index}) + prompts = DataProto( + batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + ) return prompts + @pytest.fixture + def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config): + """Mock the rollout instance with sampling_params initialized.""" + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + rollout = SGLangRollout( + actor_module="", + config=search_rollout_config, + processing_class=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) + rollout.sampling_params = { + "n": 1, + "max_new_tokens": search_rollout_config.response_length, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + } + return rollout + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config): - rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + def test_tools_registration( + self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config + ): + rollout = SGLangRollout( + actor_module="", + config=search_rollout_config, + processing_class=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) assert len(rollout._tool_schemas) == 1 assert "search" in rollout._tool_map.keys() from verl.tools.search_tool import SearchTool @@ -159,14 +209,28 @@ def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_r @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto): - rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + def test_rollout_req_creation( + self, + mock_env, + mock_engine, + mock_sampling, + search_rollout_config, + qwen_tokenizer, + qwen_model_config, + search_data_proto, + ): + rollout = SGLangRollout( + actor_module="", + config=search_rollout_config, + processing_class=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tools) == 1 - print(type(req_list[0].tools[0])) - assert req_list[0].tools[0] == OpenAIFunctionToolSchema( + assert len(req_list[0].tool_schemas) == 1 + print(type(req_list[0].tool_schemas[0])) + assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( type="function", function=OpenAIFunctionSchema( name="search", @@ -176,7 +240,8 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search properties={ "query_list": OpenAIFunctionPropertySchema( type="array", - description="A list of fully-formed semantic queries. The tool will return search results for each query.", + description="A list of fully-formed semantic queries. The tool will return search " + "results for each query.", items={"type": "string"}, ) }, @@ -186,35 +251,41 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search ), ) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): - search_rollout_config.multi_turn.max_turns = 1 - rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + def test_over_size_case(self, mock_rollout, search_data_proto, search_data): + mock_rollout.config.multi_turn.max_assistant_turns = 1 + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] _, expect_turn_array, _ = search_data - # here we mock a meta info with 'length'. indicate the response is truncate - rollout._handle_engine_call = MagicMock() + mock_rollout._handle_engine_call = MagicMock() future = asyncio.Future() - future.set_result({"text": expect_turn_array[0], "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "length", "length": 3000}, "prompt_tokens": 132, "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 2.23543}}) - rollout._handle_engine_call.return_value = future - rollout._tp_rank = 0 + future.set_result( + { + "text": expect_turn_array[0], + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "length", "length": 3000}, + "prompt_tokens": 132, + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 2.23543, + }, + } + ) + mock_rollout._handle_engine_call.return_value = future + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], ) ) assert len(output_req_list) == 1 output_req = output_req_list[0] assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert output_req.reward_scores == {"search": []}, f"output_req.reward_scores: {output_req.reward_scores}" - # we should only have two message, one for prompt, second for response. + assert output_req.reward_scores.get("search") == [] assert len(output_req.messages) == 2 assert output_req.messages[1] == Message( role="assistant", @@ -223,38 +294,47 @@ def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollo ) @patch.object(SearchTool, "execute", new_callable=AsyncMock) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): + def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): _, expect_turn_array, tool_return_array = search_data # Mock search tool execution to return predefined responses mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] - search_rollout_config.multi_turn.max_turns = 10 - rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - - rollout._tool_map["search"].retrieval_service_url = "mock://dummy" + mock_rollout.config.multi_turn.max_assistant_turns = 10 + mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] - rollout._handle_engine_call = MagicMock() + mock_rollout._handle_engine_call = MagicMock() futures = [asyncio.Future() for i in expect_turn_array] - for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)): - i.set_result({"text": turn, "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 2.23543}}) + for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): + i.set_result( + { + "text": turn, + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 2.23543, + }, + } + ) if idx < len(expect_turn_array) - 1: - assert rollout._function_call_parser.has_tool_call(turn) - assert rollout._function_call_parser.parse_non_stream(turn) + assert mock_rollout._function_call_parser.has_tool_call(turn) + assert mock_rollout._function_call_parser.parse_non_stream(turn) - rollout._handle_engine_call.side_effect = futures - rollout._tp_rank = 0 + mock_rollout._handle_engine_call.side_effect = futures + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(req, True, False) for req in req_list])) + output_req_list = loop.run_until_complete( + asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) + ) # Verify conversation completed successfully with proper tool usage output_req = output_req_list[0] @@ -272,10 +352,7 @@ def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_e assert search_counter == 2 @patch.object(SearchTool, "execute", new_callable=AsyncMock) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): + def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): _, expect_turn_array, tool_return_array = search_data # Mock tool execution for large batch (100 requests * 2 calls each) @@ -284,16 +361,10 @@ def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_e (tool_return_array[1], 0.0, {"status": "success"}), ] * 100 - search_rollout_config.multi_turn.max_turns = 10 - rollout = SGLangRollout( - actor_module="", - config=search_rollout_config, - tokenizer=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - rollout._tool_map["search"].retrieval_service_url = "mock://dummy" + mock_rollout.config.multi_turn.max_assistant_turns = 10 + mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - base_req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] req_nums = 100 req_list = [] @@ -307,7 +378,7 @@ def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_e req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) futures = [asyncio.Future() for _ in expect_turn_array] - for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)): + for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): fut.set_result( { "text": turn, @@ -328,9 +399,11 @@ async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_ return await fut with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): - rollout._tp_rank = 0 + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(r, True, False) for r in req_list])) + output_req_list = loop.run_until_complete( + asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) + ) # Verify all requests completed successfully assert len(output_req_list) == req_nums diff --git a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py index fe027a60e..3f30929c2 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# noqa import asyncio import time from copy import deepcopy @@ -30,7 +32,12 @@ from verl.protocol import DataProto from verl.tools.sandbox_fusion_tools import TokenBucketWorker -from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema +from verl.tools.schemas import ( + OpenAIFunctionParametersSchema, + OpenAIFunctionPropertySchema, + OpenAIFunctionSchema, + OpenAIFunctionToolSchema, +) from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout @@ -158,8 +165,14 @@ def qwen_model_config(self): def sandbox_fusion_data(self, qwen_tokenizer): user_prompt, expect_turn_array, tool_return_array = get_sandbox_fusion_messages() prompts = [[message] for message in user_prompt] - preencode_turn_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) for turn in expect_turn_array] - preencode_tool_return_array = [qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) for turn in tool_return_array] + preencode_turn_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + for turn in expect_turn_array + ] + preencode_tool_return_array = [ + qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + for turn in tool_return_array + ] return prompts, preencode_turn_array, preencode_tool_return_array @pytest.fixture @@ -169,13 +182,18 @@ def sandbox_fusion_rollout_config(self): dtype = "bfloat16" tensor_parallel_size = 1 tool_path = "./resource/tool_configs/sandbox_fusion_tool_config" - rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path) + rollout_config = get_rollout_config( + max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + ) return rollout_config @pytest.fixture def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): preencode_prompts, _, _ = sandbox_fusion_data - prompts = [qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] + prompts = [ + qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in preencode_prompts + ] input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) prompt_dict = TensorDict( { @@ -197,32 +215,50 @@ def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): dtype=object, ) index = np.array([0], dtype=object) - prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index}) + prompts = DataProto( + batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + ) return prompts - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tools_registration(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config): - rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - assert len(rollout._tool_schemas) == 1 - assert "code_interpreter" in rollout._tool_map.keys() + @pytest.fixture + def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config): + """Mock the rollout instance""" + with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object( + SGLangRollout, "_init_inference_engine", return_value=None + ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None): + rollout = SGLangRollout( + actor_module="", + config=sandbox_fusion_rollout_config, + processing_class=qwen_tokenizer, + model_hf_config=qwen_model_config, + ) + # set default sampling_params + rollout.sampling_params = { + "n": 1, + "max_new_tokens": sandbox_fusion_rollout_config.response_length, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + } + return rollout + + def test_tools_registration(self, mock_rollout): + """Test tool registration functionality""" + assert len(mock_rollout._tool_schemas) == 1 + assert "code_interpreter" in mock_rollout._tool_map.keys() from verl.tools.sandbox_fusion_tools import SandboxFusionTool - assert isinstance(rollout._tool_map["code_interpreter"], SandboxFusionTool) - assert rollout._tool_call_parser_type == "qwen25" + assert isinstance(mock_rollout._tool_map["code_interpreter"], SandboxFusionTool) + assert mock_rollout._tool_call_parser_type == "qwen25" - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto): - rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) + def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto): + """Test request creation functionality""" + req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tools) == 1 - print(type(req_list[0].tools[0])) - assert req_list[0].tools[0] == OpenAIFunctionToolSchema( + assert len(req_list[0].tool_schemas) == 1 + print(type(req_list[0].tool_schemas[0])) + assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( type="function", function=OpenAIFunctionSchema( name="code_interpreter", @@ -242,34 +278,43 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo ), ) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): - sandbox_fusion_rollout_config.multi_turn.max_turns = 1 - rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + """Test over-size response truncation case""" + mock_rollout.config.multi_turn.max_assistant_turns = 1 + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] _, expect_turn_array, tool_return_array = sandbox_fusion_data # here we mock a meta info with 'length'. indicate the response is truncate - rollout._handle_engine_call = MagicMock() + mock_rollout._handle_engine_call = MagicMock() future = asyncio.Future() - future.set_result({"text": expect_turn_array[0], "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "length", "length": 1024}, "prompt_tokens": 132, "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 9.9304039478302}}) - rollout._handle_engine_call.return_value = future - rollout._tp_rank = 0 + future.set_result( + { + "text": expect_turn_array[0], + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "length", "length": 1024}, + "prompt_tokens": 132, + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 9.9304039478302, + }, + } + ) + mock_rollout._handle_engine_call.return_value = future + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], ) ) assert len(output_req_list) == 1 output_req = output_req_list[0] assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert output_req.reward_scores == {"code_interpreter": []} + assert output_req.reward_scores.get("code_interpreter") == [] # we should only have two message, one for prompt, second for response. assert len(output_req.messages) == 2 assert output_req.messages[1] == Message( @@ -279,33 +324,42 @@ def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusi ) @skip_if_valid_sandbox(sandbox_url) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): - sandbox_fusion_rollout_config.multi_turn.max_turns = 10 - rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + """Test basic tool call case""" + mock_rollout.config.multi_turn.max_assistant_turns = 10 + mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] _, expect_turn_array, tool_return_array = sandbox_fusion_data # here we mock a meta info with 'length'. indicate the response is truncate - rollout._handle_engine_call = MagicMock() + mock_rollout._handle_engine_call = MagicMock() futures = [asyncio.Future() for i in expect_turn_array] for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)): - i.set_result({"text": turn, "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 9.9304039478302}}) + i.set_result( + { + "text": turn, + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 9.9304039478302, + }, + } + ) if idx < len(expect_turn_array) - 1: - assert rollout._function_call_parser.has_tool_call(turn) - assert rollout._function_call_parser.parse_non_stream(turn) + assert mock_rollout._function_call_parser.has_tool_call(turn) + assert mock_rollout._function_call_parser.parse_non_stream(turn) - rollout._handle_engine_call.side_effect = futures - rollout._tp_rank = 0 + mock_rollout._handle_engine_call.side_effect = futures + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], ) ) assert len(output_req_list) == 1 @@ -313,7 +367,7 @@ def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbo assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED # here we verify whether the code sandbox is executed correctly assert output_req.metrics == {"code_interpreter": ["3", "149"]} - assert rollout._handle_engine_call.call_count == 3 + assert mock_rollout._handle_engine_call.call_count == 3 assert len(output_req.messages) == 6 # user + 3*assistant + 2*tool_call code_counter = 0 for msg in output_req.messages: @@ -323,14 +377,11 @@ def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbo assert code_counter == 2 @skip_if_valid_sandbox(sandbox_url) - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): - sandbox_fusion_rollout_config.multi_turn.max_turns = 10 - rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) - self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + """Test batch tool call case""" + mock_rollout.config.multi_turn.max_assistant_turns = 10 + mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url + req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req_nums = 100 req_list = [] req_turns_counter = {} @@ -344,25 +395,39 @@ def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbo req_list.append(MagicMock(wraps=_temp_req, spec=AsyncRolloutRequest)) futures = [asyncio.Future() for i in expect_turn_array] for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)): - i.set_result({"text": turn, "meta_info": {"id": "d1188d81cba840359df5b352b344bc8e", "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, "e2e_latency": 9.9304039478302}}) + i.set_result( + { + "text": turn, + "meta_info": { + "id": "d1188d81cba840359df5b352b344bc8e", + "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "prompt_tokens": len(turn), + "completion_tokens": 100, + "cached_tokens": 0, + "e2e_latency": 9.9304039478302, + }, + } + ) if idx < len(expect_turn_array) - 1: - assert rollout._function_call_parser.has_tool_call(turn) - assert rollout._function_call_parser.parse_non_stream(turn) + assert mock_rollout._function_call_parser.has_tool_call(turn) + assert mock_rollout._function_call_parser.parse_non_stream(turn) req_turns_map[_temp_req.batch_data_id] = futures req_turns_counter[_temp_req.batch_data_id] = 0 - async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs): + async def hacked_handle_engine_call( + self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs + ): result = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] req_turns_counter[_req.batch_data_id] += 1 re = await result return re with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): - rollout._tp_rank = 0 + mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], ) ) assert len(output_req_list) == req_nums @@ -379,6 +444,22 @@ async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: code_counter += 1 assert code_counter == 2 + def test_sampling_params_functionality(self, mock_rollout): + """Test sampling_params functionality""" + # test basic copy functionality + copied_params = mock_rollout.sampling_params.copy() + assert copied_params == mock_rollout.sampling_params + assert copied_params is not mock_rollout.sampling_params + + # test parameter update + copied_params.update({"temperature": 0.8, "top_p": 0.9}) + assert copied_params["temperature"] == 0.8 + assert copied_params["top_p"] == 0.9 + + # ensure original parameters are not modified + assert "temperature" not in mock_rollout.sampling_params + assert "top_p" not in mock_rollout.sampling_params + class RayMultiProcessTestCase(MultiProcessTestCase): def setUp(self): @@ -480,7 +561,9 @@ def test_rate_limiter(self): from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=3) - exec_worker = init_execution_pool(num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode) + exec_worker = init_execution_pool( + num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode + ) center = TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) ray.get(exec_worker.ping.remote()) @@ -510,7 +593,9 @@ def test_rotten_execution(self): from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) - exec_worker = init_execution_pool(num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode) + exec_worker = init_execution_pool( + num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode + ) ray.get(exec_worker.ping.remote()) def fn(i): @@ -540,7 +625,9 @@ def test_rate_limiter(self): from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) - exec_worker = init_execution_pool(num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode) + exec_worker = init_execution_pool( + num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode + ) center = TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) ray.get(exec_worker.ping.remote()) diff --git a/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py b/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py new file mode 100644 index 000000000..3ccde1852 --- /dev/null +++ b/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py @@ -0,0 +1,174 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +usage: torchrun --standalone --nnodes=1 \ + --nproc_per_node=2 $(which pytest) \ + -s test_sglang_async_rollout_w_interaction.py +""" + +import numpy as np +import torch +from tensordict import TensorDict +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from utils_sglang import ( + are_lists_similar, + clean_torchelastic_env, + generate_hf_output, + get_rollout_config, + initialize_global_process_group, + load_tokenizer_and_model, + prepare_inputs, +) + +from verl import DataProto +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout +from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager + + +def test_async_sglang_rollout_w_interaction(): + assert torch.cuda.device_count() >= 2 + initialize_global_process_group() + clean_torchelastic_env() + + max_prompt_length = 32 + max_response_length = 16 + dtype = "bfloat16" + tensor_parallel_size = 2 + local_model_path = "Qwen/Qwen2.5-0.5B" + + tokenizer, actor_model = load_tokenizer_and_model(local_model_path) + + preencode_prompts = [ + [{"role": "user", "content": prompt, "tool_calls": None}] + for prompt in [ + "Who won the Champions League in 2019?", + "The founder of Apple is", + "What's the best way to learn python?", + ] + ] + interaction_kwargs = [ + {"name": "gsm8k", "query": "Who won the Champions League in 2019?", "ground_truth": "Real Madrid"}, + {"name": "gsm8k", "query": "The founder of Apple is", "ground_truth": "Steve Jobs"}, + {"name": "gsm8k", "query": "What's the best way to learn python?", "ground_truth": "Learn python from scratch"}, + ] + prompts = [ + tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in preencode_prompts + ] + input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) + + hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + + fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) + inference_device_mesh_cpu = init_device_mesh( + "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") + ) + + fsdp_model = FSDP( + actor_model, + use_orig_params=True, + device_id=fsdp_device_mesh["fsdp"].get_local_rank(), + mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)), + sharding_strategy=ShardingStrategy.FULL_SHARD, + device_mesh=fsdp_device_mesh, + ) + + # Create a temporary interaction config file for testing + import tempfile + + from omegaconf import OmegaConf + + interaction_config = { + "interaction": [ + {"name": "gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}} + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(interaction_config, f.name) + interaction_config_path = f.name + + rollout_config = get_rollout_config( + max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path + ) + rollout = SGLangRollout( + actor_module=local_model_path, + config=rollout_config, + processing_class=tokenizer, + model_hf_config=actor_model.config, + ) + + rollout_sharding_manager = FSDPSGLangShardingManager( + module=fsdp_model, + inference_engine=rollout._engine, + model_config=actor_model.config, + rollout_config=rollout_config, + full_params=True, + device_mesh=inference_device_mesh_cpu, + ) + + with rollout_sharding_manager: + prompt_dict = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=input_ids.shape[0], + ) + print(f"preprocessed {input_ids.shape=}") + + messages = np.asarray(preencode_prompts) + prompts = DataProto( + batch=prompt_dict, + non_tensor_batch={"raw_prompt": messages, "interaction_kwargs": np.asarray(interaction_kwargs)}, + ) + + prompts.meta_info.update( + { + "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": tokenizer.pad_token_id, + } + ) + + prompts = rollout_sharding_manager.preprocess_data(prompts) + # log_gpu_memory_usage("Before generating sequences", logger=None) + output = rollout.generate_sequences(prompts=prompts) + print(f"generated {output.batch['responses'].shape=}") + # log_gpu_memory_usage("After generating sequences", logger=None) + output = rollout_sharding_manager.postprocess_data(output) + print(f"postprocessed {output.batch['responses'].shape=}") + sglang_output = output.to("cpu") + + sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) + + print(f"hf response: {hf_response_tokens}") + print(f"sglang response: {sglang_response_tokens}") + assert are_lists_similar(hf_response_tokens, sglang_response_tokens) + print("SGLang w interaction Test Passed!") + + # Clean up temporary config file + import os + + os.unlink(interaction_config_path) + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + test_async_sglang_rollout_w_interaction() diff --git a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py b/tests/workers/rollout/test_sglang_async_rollout_w_tools.py index c9f5ad68a..20faab851 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_w_tools.py @@ -60,13 +60,18 @@ def test_async_sglang_rollout_w_tool(): "What's the best way to learn python?", ] ] - prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] + prompts = [ + tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in preencode_prompts + ] input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) - inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp")) + inference_device_mesh_cpu = init_device_mesh( + "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") + ) fsdp_model = FSDP( actor_model, @@ -77,13 +82,25 @@ def test_async_sglang_rollout_w_tool(): device_mesh=fsdp_device_mesh, ) - rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None) - rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) + rollout_config = get_rollout_config( + max_response_length, + max_prompt_length, + dtype, + tensor_parallel_size, + "./resource/tool_configs/sandbox_fusion_tool_config", + ) + rollout = SGLangRollout( + actor_module=local_model_path, + config=rollout_config, + processing_class=tokenizer, + model_hf_config=actor_model.config, + ) rollout_sharding_manager = FSDPSGLangShardingManager( module=fsdp_model, inference_engine=rollout._engine, model_config=actor_model.config, + rollout_config=rollout_config, full_params=True, device_mesh=inference_device_mesh_cpu, ) @@ -100,7 +117,13 @@ def test_async_sglang_rollout_w_tool(): print(f"preprocessed {input_ids.shape=}") messages = np.asarray(preencode_prompts) - prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages}) + prompts = DataProto( + batch=prompt_dict, + non_tensor_batch={ + "raw_prompt": messages, + "tools_kwargs": np.array([{}] * input_ids.shape[0], dtype=object), + }, + ) prompts.meta_info.update( { diff --git a/tests/workers/rollout/test_sglang_multi_interaction.py b/tests/workers/rollout/test_sglang_multi_interaction.py new file mode 100644 index 000000000..465470fbd --- /dev/null +++ b/tests/workers/rollout/test_sglang_multi_interaction.py @@ -0,0 +1,426 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Test for multi-interaction support in SGLangRollout. +usage: torchrun --standalone --nnodes=1 \ + --nproc_per_node=2 $(which pytest) \ + -s test_sglang_multi_interaction.py +""" + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf +from transformers import AutoTokenizer + +from verl.interactions.base import BaseInteraction +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout + + +class MockInteraction(BaseInteraction): + """Mock interaction for testing.""" + + def __init__(self, config): + super().__init__(config) + self.started_instances = set() + + async def start_interaction(self, instance_id=None, **kwargs): + if instance_id is None: + instance_id = "mock_instance" + self.started_instances.add(instance_id) + return instance_id + + async def generate_response(self, instance_id, messages, **kwargs): + return False, f"Mock response from {self.name}", 1.0, {} + + +def create_mock_config_with_multi_interactions(): + """Create a mock configuration with multiple interactions.""" + # Create temporary interaction config file + interaction_config = { + "interaction": [ + { + "name": "mock_agent1", + "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", + "config": {"param1": "value1"}, + }, + { + "name": "mock_agent2", + "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", + "config": {"param2": "value2"}, + }, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(interaction_config, f.name) + interaction_config_path = f.name + + # Create mock SGLangRollout config + config = DictConfig( + { + "multi_turn": { + "interaction_config_path": interaction_config_path, + "tool_config_path": None, + "enable": True, + "max_assistant_turns": 5, + "max_user_turns": 3, + "use_inference_chat_template": True, + "tokenization_sanity_check_mode": "off", + }, + "prompt_length": 32, + "response_length": 16, + "max_model_len": 512, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.8, + "load_format": "dummy", + "enforce_eager": True, + "free_cache_engine": False, + "calculate_log_probs": False, + "tensor_model_parallel_size": 1, + "n": 1, + "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, + } + ) + + return config, interaction_config_path + + +def setup_distributed(): + """Initialize distributed environment if not already initialized.""" + if not dist.is_initialized(): + dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo") + + +class TestSGLangMultiInteraction: + def test_initialize_multiple_interactions(self): + """Test that SGLangRollout can initialize multiple interactions.""" + setup_distributed() + config, temp_config_path = create_mock_config_with_multi_interactions() + + try: + # Mock SGLang engine and initialization methods like the reference test + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + # Create a real tokenizer like the reference test + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + # Mock model config + mock_model_config = MagicMock() + mock_model_config.max_position_embeddings = 2048 + # since this is a mock, we can set any rope scaling config + # to test the rope_scaling logic at the same time of this test + mock_model_config.rope_scaling = { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn", + } + + # Create SGLangRollout instance + rollout = SGLangRollout( + actor_module="mock_model", + config=config, + processing_class=tokenizer, + model_hf_config=mock_model_config, + port=None, + trust_remote_code=False, + device_mesh=None, + ) + + # Check that interactions were initialized + assert len(rollout.interaction_map) == 2 + assert "mock_agent1" in rollout.interaction_map + assert "mock_agent2" in rollout.interaction_map + + # Use class name comparison instead of isinstance for multi-process compatibility + assert rollout.interaction_map["mock_agent1"].__class__.__name__ == "MockInteraction" + assert rollout.interaction_map["mock_agent2"].__class__.__name__ == "MockInteraction" + + # Also check that they are instances of BaseInteraction (which should work across processes) + assert isinstance(rollout.interaction_map["mock_agent1"], BaseInteraction) + assert isinstance(rollout.interaction_map["mock_agent2"], BaseInteraction) + + # Check that names were set correctly + assert rollout.interaction_map["mock_agent1"].name == "mock_agent1" + assert rollout.interaction_map["mock_agent2"].name == "mock_agent2" + + finally: + os.unlink(temp_config_path) + + def test_interaction_selection_by_name(self): + """Test that interactions are selected by name from interaction_kwargs.""" + setup_distributed() + config, temp_config_path = create_mock_config_with_multi_interactions() + + try: + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + mock_model_config = MagicMock() + mock_model_config.max_position_embeddings = 2048 + mock_model_config.rope_scaling = { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn", + } + + rollout = SGLangRollout( + actor_module="mock_model", + config=config, + processing_class=tokenizer, + model_hf_config=mock_model_config, + port=None, + trust_remote_code=False, + device_mesh=None, + ) + + # Test interaction selection logic + from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message + + # Create a mock request with specific interaction name + req = AsyncRolloutRequest( + request_id="test_req", + state=AsyncRolloutRequestStateEnum.INTERACTING, + messages=[Message(role="user", content="test message")], + interaction_kwargs={"name": "mock_agent2", "test_param": "value"}, + input_ids=None, + prompt_ids=None, + response_ids=None, + attention_mask=None, + prompt_attention_mask=None, + response_attention_mask=None, + position_ids=None, + prompt_position_ids=None, + response_position_ids=None, + loss_mask=None, + prompt_loss_mask=None, + response_loss_mask=None, + reward_scores={}, + max_prompt_len=32, + max_response_len=16, + max_model_len=512, + use_inference_chat_template=True, + tokenization_sanity_check_mode="disable", + processing_class=tokenizer, + ) + + # Test that the correct interaction is selected + interaction_name = req.interaction_kwargs.get("name", "gsm8k") + assert interaction_name == "mock_agent2" + assert interaction_name in rollout.interaction_map + + selected_interaction = rollout.interaction_map[interaction_name] + assert selected_interaction.name == "mock_agent2" + + finally: + os.unlink(temp_config_path) + + def test_fallback_to_default_interaction(self): + """Test fallback to default interaction when name is not specified.""" + setup_distributed() + # Create config with gsm8k interaction + interaction_config = { + "interaction": [ + { + "name": "gsm8k", + "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", + "config": {}, + } + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + OmegaConf.save(interaction_config, f.name) + interaction_config_path = f.name + + config = DictConfig( + { + "multi_turn": { + "interaction_config_path": interaction_config_path, + "tool_config_path": None, + "enable": True, + "max_assistant_turns": 5, + "max_user_turns": 3, + "use_inference_chat_template": True, + "tokenization_sanity_check_mode": "disable", + }, + "prompt_length": 32, + "response_length": 16, + "max_model_len": 512, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.8, + "load_format": "dummy", + "enforce_eager": True, + "free_cache_engine": False, + "calculate_log_probs": False, + "tensor_model_parallel_size": 1, + "n": 1, + "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, + } + ) + + try: + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + mock_model_config = MagicMock() + mock_model_config.max_position_embeddings = 2048 + mock_model_config.rope_scaling = { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn", + } + + rollout = SGLangRollout( + actor_module="mock_model", + config=config, + processing_class=tokenizer, + model_hf_config=mock_model_config, + port=None, + trust_remote_code=False, + device_mesh=None, + ) + + # Test that default interaction name works + interaction_kwargs_without_name = {"test_param": "value"} + default_name = interaction_kwargs_without_name.get("name", "gsm8k") + assert default_name == "gsm8k" + assert default_name in rollout.interaction_map + + finally: + os.unlink(interaction_config_path) + + def test_error_on_missing_interaction(self): + """Test that error is raised when requested interaction is not found.""" + setup_distributed() + config, temp_config_path = create_mock_config_with_multi_interactions() + + try: + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + mock_model_config = MagicMock() + mock_model_config.max_position_embeddings = 2048 + mock_model_config.rope_scaling = { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn", + } + + rollout = SGLangRollout( + actor_module="mock_model", + config=config, + processing_class=tokenizer, + model_hf_config=mock_model_config, + port=None, + trust_remote_code=False, + device_mesh=None, + ) + + # Test error when requesting non-existent interaction + non_existent_name = "non_existent_interaction" + assert non_existent_name not in rollout.interaction_map + + # This should raise ValueError in actual usage + available_interactions = list(rollout.interaction_map.keys()) + assert "mock_agent1" in available_interactions + assert "mock_agent2" in available_interactions + assert non_existent_name not in available_interactions + + finally: + os.unlink(temp_config_path) + + def test_backward_compatibility_no_interaction_config(self): + """Test backward compatibility when no interaction config is provided.""" + setup_distributed() + # Create config without interaction config + config = DictConfig( + { + "multi_turn": { + "interaction_config_path": None, + "tool_config_path": None, + "enable": True, + "max_assistant_turns": 5, + "max_user_turns": 3, + "use_inference_chat_template": True, + "tokenization_sanity_check_mode": "disable", + }, + "prompt_length": 32, + "response_length": 16, + "max_model_len": 512, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.8, + "load_format": "dummy", + "enforce_eager": True, + "free_cache_engine": False, + "calculate_log_probs": False, + "tensor_model_parallel_size": 1, + "n": 1, + "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, + } + ) + + with ( + patch.object(SGLangRollout, "_init_distributed_env", return_value=None), + patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object(SGLangRollout, "_init_sampling_params", return_value=None), + ): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + mock_model_config = MagicMock() + mock_model_config.max_position_embeddings = 2048 + mock_model_config.rope_scaling = { + "factor": 4.0, + "original_max_position_embeddings": 32768, + "type": "yarn", + } + + rollout = SGLangRollout( + actor_module="mock_model", + config=config, + processing_class=tokenizer, + model_hf_config=mock_model_config, + port=None, + trust_remote_code=False, + device_mesh=None, + ) + + # Check that no interactions were initialized + assert len(rollout.interaction_map) == 0 diff --git a/tests/workers/rollout/test_sglang_rollout_sharding_manager.py b/tests/workers/rollout/test_sglang_rollout_sharding_manager.py new file mode 100644 index 000000000..0d3c7b5da --- /dev/null +++ b/tests/workers/rollout/test_sglang_rollout_sharding_manager.py @@ -0,0 +1,57 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets + +_TENSOR_1MB = torch.zeros(512, 512) +_BYTES_1MB = 1 << 20 + + +@pytest.mark.parametrize( + "named_tensors, bucket_size_mb, gt_groups", + [ + ( + [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], + 0.5 * _BYTES_1MB, + [["a"], ["b"]], + ), + ( + [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], + 1 * _BYTES_1MB, + [["a"], ["b"]], + ), + ( + [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], + 1.5 * _BYTES_1MB, + [["a"], ["b"]], + ), + ( + [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], + 2 * _BYTES_1MB, + [["a", "b"]], + ), + ], +) +def test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]): + named_tensors_iter = iter(named_tensors) + groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb)) + assert len(groups) == len(gt_groups) + for group, gt_group in zip(groups, gt_groups, strict=True): + assert len(group) == len(gt_group) + for (name, _), (gt_name) in zip(group, gt_group, strict=True): + assert name == gt_name diff --git a/tests/workers/rollout/test_sglang_spmd.py b/tests/workers/rollout/test_sglang_spmd.py index 0ad6445a9..0995e2f64 100644 --- a/tests/workers/rollout/test_sglang_spmd.py +++ b/tests/workers/rollout/test_sglang_spmd.py @@ -57,7 +57,9 @@ def test_sglang_spmd(): hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) tensor_parallel_size = 2 - inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) + inference_device_mesh_cpu = init_device_mesh( + "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"] + ) tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() if tp_rank == 0: @@ -67,6 +69,7 @@ def test_sglang_spmd(): mem_fraction_static=0.5, enable_memory_saver=True, tp_size=inference_device_mesh_cpu["tp"].size(), + attention_backend="fa3", ) input_ids = input_ids.cuda() diff --git a/tests/workers/rollout/test_vllm_hf_loader.py b/tests/workers/rollout/test_vllm_hf_loader.py deleted file mode 100644 index 523065bd6..000000000 --- a/tests/workers/rollout/test_vllm_hf_loader.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from vllm import SamplingParams - -from verl.third_party.vllm import LLM, vllm_version -from verl.utils.torch_functional import pad_sequence_to_length -from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs - - -def levenshtein(s1, s2): - m, n = len(s1), len(s2) - # Initialize matrix of zeros - dp = [[0] * (n + 1) for _ in range(m + 1)] - # Initialize first column and first row of the matrix - for i in range(m + 1): - dp[i][0] = i # Deletion from s1 to empty string - for j in range(n + 1): - dp[0][j] = j # Insertion to s1 from empty string - # Compute the Levenshtein distance matrix - for i in range(1, m + 1): - for j in range(1, n + 1): - cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match - dp[i][j] = min( - dp[i - 1][j] + 1, # Deletion - dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost, # Substitution - ) - return dp[m][n] - - -def are_lists_similar(a, b): - if len(a) != len(b): - print("The lists are of different lengths.") - return False - - total_length = 0 - total_diff = 0 - - for s1, s2 in zip(a, b): - max_len = max(len(s1), len(s2)) - total_length += max_len - diff = levenshtein(s1, s2) - total_diff += diff - print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") - - percentage_difference = (total_diff / total_length) * 100 - print(f"Total difference: {percentage_difference:.2f}%") - - return percentage_difference <= 10 - - -def test_vllm_with_hf(): - assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." - - # fill rollout config - max_prompt_length = 16 - max_response_length = 16 - - # Initialize model and token - local_cache_path = "~/.cache/verl/rlhf" - local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = "deepseek-ai/deepseek-llm-7b-chat" - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path) - - preencode_prompts = [ - "Who won the Champions League in 2019?", - "The founder of Apple is", - "What's your name", - ] - tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) - input_ids = prompts["input_ids"] - attention_mask = prompts["attention_mask"] - - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) - - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path) - actor_model.to(torch.bfloat16) - - actor_model_config = AutoConfig.from_pretrained(local_model_path) - - temperature = 0 - top_p = 1 - - kwargs = dict(n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True) - - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - kwargs["detokenize"] = False - sampling_params = SamplingParams(**kwargs) - - tensor_parallel_size = 4 - - llm = LLM( - model=actor_model, - tokenizer=tokenizer, - model_hf_config=actor_model_config, - tensor_parallel_size=tensor_parallel_size, - dtype="bfloat16", - gpu_memory_utilization=0.1, - load_format="hf", - ) - - print("start generation") - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - batch_size = input_ids.size(0) - - idx_list = [] - # parse idx from torch.Tensor to List[List[str]] - for i in range(batch_size): - idx_list.append(_pre_process_inputs(tokenizer.pad_token_id, input_ids[i])) - outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False) - vllm_output = outputs[0].cuda() - llm.free_cache_engine() - llm = None - import gc - - torch.cuda.empty_cache() - gc.collect() - - generation_config = GenerationConfig(do_sample=False) - actor_model.cuda() - output = actor_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_response_length, - # max_length=max_length, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config=generation_config, - # renormalize_logits=True, - output_scores=False, # this is potentially very large - return_dict_in_generate=True, - use_cache=False, - ) # may OOM when use_cache = True - seq = output.sequences - response = seq[:, max_prompt_length:] - - hf_response_tokens = tokenizer.batch_decode(response) - vllm_response_tokens = tokenizer.batch_decode(vllm_output) - - print(f"hf response: {hf_response_tokens}") - print(f"vllm response: {vllm_response_tokens}") - assert are_lists_similar(hf_response_tokens, vllm_response_tokens), "Strings differ more than 10%:\n" - print("Check Pass") - - -# if __name__ == "__main__": -# test_vllm_with_hf() diff --git a/tests/workers/rollout/test_vllm_multi_turn.py b/tests/workers/rollout/test_vllm_multi_turn.py deleted file mode 100644 index b705d86a9..000000000 --- a/tests/workers/rollout/test_vllm_multi_turn.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import json -from typing import Any, Dict - -import numpy as np -import ray -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion import ChatCompletion -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ErrorResponse - -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto - - -def init_config() -> DictConfig: - config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") - model_path = "Qwen/Qwen2-7B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.chat_scheduler = "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -def test_vllm_multi_turn(config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. Init rollout manager =========================== - model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:]) - async_rollout_manager = init_async_rollout_manager(config) - - # test sleep and wake_up - async_rollout_manager.sleep() - async_rollout_manager.wake_up() - - async_chat_scheduler = async_rollout_manager.chat_scheduler - - # =========================== 2. Multi turn rollout =========================== - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - assert exception is None, f"exception: {exception}" - messages, round = info["messages"], info["round"] - message = completions.choices[0].message - messages.append({"role": message.role, "content": message.content}) - print(f"[round={round}] role: {message.role}, content: {message.content}") - - extra_headers = {"x-request-id": completions.id} - if round == 0: - messages.append({"role": "user", "content": "What is your name?"}) - await async_chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 1}, - model=model_name, - messages=messages, - extra_headers=extra_headers, - ) - elif round == 1: - messages.append({"role": "user", "content": "What is your favorite color?"}) - await async_chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 2}, - model=model_name, - messages=messages, - extra_headers=extra_headers, - ) - else: - print("Done!") - - messages = [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}] - async_rollout_manager.submit_chat_completions( - callback=callback, - callback_additional_info={"messages": messages, "round": 0}, - model=model_name, - messages=messages, - ) - assert len(messages) == 6 - for round, message in enumerate(messages): - if round % 2 == 0: - assert message["role"] == "user" - else: - assert message["role"] == "assistant" - - # =========================== 3. Generate sequences =========================== - raw_prompts = [ - [ - { - "role": "user", - "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", - } - ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array(raw_prompts), - }, - ) - result = async_rollout_manager.generate_sequences(prompts=batch) - seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) - assert len(result) == 2 - assert result.batch["input_ids"].size(1) == seq_len - assert result.batch["attention_mask"].size(1) == seq_len - assert result.batch["position_ids"].size(1) == seq_len - - ray.shutdown() - - -async def test_vllm_streaming_response(config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_USE_V1": "1", - } - } - ) - - model_name = "/".join(config.actor_rollout_ref.model.path.split("/")[-2:]) - async_rollout_manager = init_async_rollout_manager(config) - async_llm_server = async_rollout_manager.async_llm_servers[0] - - # non-streaming request - request = ChatCompletionRequest( - model=model_name, - messages=[{"role": "user", "content": "What is your name?"}], - stream=False, - ) - generator = async_llm_server.chat_completion_generator.remote(request) - async for ref in generator: - status_code, data = await ref - print(f">>>> status_code: {status_code}, {data}") - data = data[len("data: ") :].rstrip() - if status_code != 200: - response = ErrorResponse(**json.loads(data)) - else: - response = ChatCompletionResponse(**json.loads(data)) - assert response.choices[0].message.role == "assistant" - assert response.choices[0].message.content is not None - - # streaming request - request = ChatCompletionRequest( - model=model_name, - messages=[{"role": "user", "content": "How are you?"}], - stream=True, - ) - generator = async_llm_server.chat_completion_generator.remote(request) - async for ref in generator: - status_code, data = await ref - print(f">>>> status_code: {status_code}, {data}") - data = data[len("data: ") :].rstrip() - if status_code != 200: - response = ErrorResponse(**json.loads(data)) - elif data == "[DONE]": - break - else: - response = ChatCompletionStreamResponse(**json.loads(data)) - assert response.choices[0].delta.role is None or response.choices[0].delta.role == "assistant" - assert response.choices[0].delta.content is not None - - ray.shutdown() - - -if __name__ == "__main__": - config = init_config() - test_vllm_multi_turn(config) - asyncio.run(test_vllm_streaming_response(config)) diff --git a/tests/workers/rollout/test_vllm_tool_calling.py b/tests/workers/rollout/test_vllm_tool_calling.py deleted file mode 100644 index efc8fbf54..000000000 --- a/tests/workers/rollout/test_vllm_tool_calling.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import os -import re -import socket -import sys -import tempfile -from contextlib import asynccontextmanager -from typing import Any, Dict - -import aiohttp -import fastapi -import numpy as np -import ray -import uvicorn -from datasets import load_dataset -from omegaconf import OmegaConf -from openai.types.chat.chat_completion import ChatCompletion -from starlette.requests import Request -from starlette.responses import JSONResponse - -from examples.ppo_trainer.naive_chat_scheduler import NaiveChatCompletionScheduler -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto - - -def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -@ray.remote(num_cpus=1) -class Sandbox: - """Sandbox to execute python code. - - WARNING: This class is for testing purpose only, do not use it in production. - Please use a sandbox with strong isolation and security restrictions instead. - """ - - def __init__(self): - self.address = ray._private.services.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - - async def code_execution(self, request: Request): - request_json = await request.json() - code = request_json["code"] - print(f"execute code:\n{code}") - - _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) - with open(temp_file, "w") as f: - f.write(code) - - try: - process = await asyncio.create_subprocess_exec(sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) - - stdout, stderr = await process.communicate() - - return JSONResponse(content={"stdout": stdout.decode(), "stderr": stderr.decode(), "returncode": process.returncode}) - finally: - try: - os.unlink(temp_file) - except: # noqa: E722 - pass - - async def _start_fastapi_server(self): - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") - self.server_ready.set() - yield - - print("FastAPI shutdown, maybe address already in use, exit process immediately.") - os._exit(-1) - - app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/code/execution", self.code_execution, methods=["POST"]) - - self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") - server = uvicorn.Server(config) - await server.serve() - - async def get_server_address(self) -> str: - """Get FastAPI server address.""" - await self.server_ready.wait() - return f"{self.address}:{self.port}" - - -class ToolChatCompletionScheduler(NaiveChatCompletionScheduler): - """This is a demo chat completion scheduler that supports sandbox code execution - described in ReTool paper: https://arxiv.org/pdf/2504.11536 - """ - - def __init__(self, config, model_path, server_addresses, sandbox_address, system_prompt, **kwargs): - super().__init__(config, model_path, server_addresses, **kwargs) - self.sandbox_address = sandbox_address - self.system_prompt = system_prompt - - async def sandbox_code_execution(self, code: str) -> Dict[str, Any]: - """Execute python code in sandbox.""" - try: - session = aiohttp.ClientSession() - async with session.post( - url=f"http://{self.sandbox_address}/code/execution", - json={"code": code}, - ) as resp: - return await resp.json() - finally: - await session.close() - - async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: - kwargs = dict( - n=self.config.n, - max_completion_tokens=self.config.response_length, - temperature=self.config.temperature, - top_p=self.config.top_p, - extra_body={ - "include_stop_str_in_output": True, - "stop": ["
", ""], - }, - ) - - do_sample = batch.meta_info.get("do_sample", True) - is_validate = batch.meta_info.get("validate", False) - if not do_sample or is_validate: - kwargs["n"] = 1 - kwargs["temperature"] = 0 - - kwargs.update(sampling_params) - print(f"[ToolChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - - max_turns = 3 - - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - batch_conversations, batch_index, turn = ( - info["batch_conversations"], - info["batch_index"], - info["turn"], - ) - role, content = completions.choices[0].message.role, completions.choices[0].message.content - batch_conversations[batch_index].append({"role": role, "content": content}) - - # STEP 0: check if we reach max turns - if turn == max_turns: - print(f"[id={completions.id},turn={turn}] Reach max turns {max_turns}, done!") - return - - # STEP 1: check if we got answer - matches = re.findall(r"(.*?)", content, re.DOTALL) - if matches: - print(f"[id={completions.id},turn={turn}] Got answer: {matches[0]}, done!") - return - - # STEP 2: check if we got code block - matches = re.findall(r"\s*```python(.*?)```\s*", content, re.DOTALL) - if not matches: - print(f"[id={completions.id},turn={turn}] No code block found, done!") - return - - # STEP 3: execute code block in sandbox - code = matches[0].strip() - result = await self.sandbox_code_execution(code) - stdout, stderr = result["stdout"], result["stderr"] - batch_conversations[batch_index].append({"role": "tool", "content": f"{stdout}{stderr}"}) - print(f"[id={completions.id},turn={turn}] Code block executed, continue...") - - # STEP 4: resubmit chat completions with code block output - extra_headers = {"x-request-id": completions.id} - await self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "turn": turn + 1, - }, - model=self.model_name, - messages=batch_conversations[batch_index], - extra_headers=extra_headers, - **kwargs, - ) - - tasks, batch_conversations = [], [None] * len(batch) - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): - # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] - batch_conversations[batch_index] = [{"role": "system", "content": self.system_prompt}] + list(conversation) - tasks.append( - asyncio.create_task( - self.submit_chat_completions( - callback=callback, - callback_additional_info={ - "batch_conversations": batch_conversations, - "batch_index": batch_index, - "turn": 1, - }, - model=self.model_name, - messages=batch_conversations[batch_index], - **kwargs, - ) - ) - ) - - await asyncio.gather(*tasks) - print("[NaiveChatCompletionScheduler] generate_sequences done") - - # _postprocess assumes n>=1 - batch_conversations = [[conversation] for conversation in batch_conversations] - return self._postprocess(batch, batch_conversations, kwargs["n"]) - - -system_prompt = """ -You are a helpful assistant. Let's solve math problem in following steps: -1. Write a python code first and return the code to user, the code must be in following format: - - -```python -import os - -print(...) -``` - - -The code must explictly print necessary output to stdout. Remember stop generation at immediately and return the code. -2. User will send the python code to a external sandbox to execute and get output from stdout. -3. User will send the output in format output to you, and you should use the output to answer the question. -The answer format must be: \\boxed{'The final answer goes here.'} -""" - - -def test_vllm_tool_calling(): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # Load config - config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") - config.actor_rollout_ref.model.path = "Qwen/Qwen2-7B-Instruct" - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.chat_scheduler = "tests.workers.rollout.test_vllm_tool_calling.ToolChatCompletionScheduler" - config.actor_rollout_ref.rollout.prompt_length = 8192 - config.actor_rollout_ref.rollout.response_length = 8192 - - # Init sandbox and async rollout manager - sandbox = Sandbox.options(num_cpus=1).remote() - sandbox_address = ray.get(sandbox.get_server_address.remote()) - async_rollout_manager = init_async_rollout_manager(config, scheduler_kwargs={"sandbox_address": sandbox_address, "system_prompt": system_prompt}) - - # Build dataset - dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train") - prompts = DataProto(non_tensor_batch={"raw_prompt": np.array([[{"role": "user", "content": problem}] for problem in dataset["Problem"]])}) - - result = async_rollout_manager.generate_sequences(prompts=prompts) - assert len(result) == len(dataset) - - -if __name__ == "__main__": - test_vllm_tool_calling() diff --git a/tests/workers/rollout/utils_sglang.py b/tests/workers/rollout/utils_sglang.py index 35c43a83a..d16b09feb 100644 --- a/tests/workers/rollout/utils_sglang.py +++ b/tests/workers/rollout/utils_sglang.py @@ -43,7 +43,7 @@ def are_lists_similar(a, b, threshold=10): return False total_length = 0 total_diff = 0 - for s1, s2 in zip(a, b): + for s1, s2 in zip(a, b, strict=True): max_len = max(len(s1), len(s2)) total_length += max_len total_diff += levenshtein(s1, s2) @@ -96,7 +96,9 @@ def prepare_inputs(tokenizer, prompts, max_prompt_length): pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id tokenized = tokenizer(prompts, return_tensors="pt", padding=True) input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length(tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True) + attention_mask = pad_sequence_to_length( + tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True + ) position_ids = compute_position_id_with_mask(attention_mask) position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True) return input_ids, attention_mask, position_ids @@ -120,7 +122,14 @@ def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response return tokenizer.batch_decode(response) -def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_config_path): +def get_rollout_config( + max_response_length, + max_prompt_length, + dtype, + tensor_parallel_size, + tool_config_path=None, + interaction_config_path=None, +): sampling_params = dict( n=1, temperature=0, @@ -138,9 +147,10 @@ def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_par rollout_config = OmegaConf.create( { "name": "sglang", + "mode": "sync", "load_format": "dummy_dtensor", "enforce_eager": False, - "free_cache_engine": False, + "free_cache_engine": True, "dtype": dtype, "gpu_memory_utilization": 0.5, "ignore_eos": False, @@ -148,11 +158,16 @@ def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_par "prompt_length": max_prompt_length, "response_length": max_response_length, "tensor_model_parallel_size": tensor_parallel_size, + # set to 128MB only for testing + "update_weights_bucket_megabytes": 128, "multi_turn": { - "max_turns": 4, + "max_assistant_turns": 4, + "max_user_turns": 4, "enable": True, "tool_config_path": tool_config_path, - "format": "chatml", + "interaction_config_path": interaction_config_path, + "use_inference_chat_template": False, + "tokenization_sanity_check_mode": "strict", }, "max_model_len": None, **sampling_params, diff --git a/verl/__init__.py b/verl/__init__.py index 881f9c74f..6dbdd333f 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -12,17 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import logging import os +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as get_version -import pkg_resources from packaging.version import parse as parse_version -from pkg_resources import DistributionNotFound from .protocol import DataProto from .utils.device import is_npu_available from .utils.logging_utils import set_basic_config + version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) with open(os.path.join(version_folder, "version/version")) as f: @@ -35,8 +37,6 @@ __all__ = ["DataProto", "__version__"] if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": - import importlib - if importlib.util.find_spec("modelscope") is None: raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") # Patch hub to download models from modelscope to speed up. @@ -45,15 +45,21 @@ patch_hub() if is_npu_available: - package_name = 'transformers' - required_version_spec = '4.51.0' + from .models.transformers import npu_patch as npu_patch + + package_name = "transformers" + required_version_spec = "4.52.4" try: - installed_version = pkg_resources.get_distribution(package_name).version + installed_version = get_version(package_name) installed = parse_version(installed_version) required = parse_version(required_version_spec) - if not installed >= required: - raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.") - except DistributionNotFound: + if installed < required: + raise ValueError( + f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is " + f"{installed}." + ) + except PackageNotFoundError as e: raise ImportError( - f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}") + f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}" + ) from e diff --git a/verl/base_config.py b/verl/base_config.py new file mode 100644 index 000000000..0cd117bb6 --- /dev/null +++ b/verl/base_config.py @@ -0,0 +1,91 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from dataclasses import ( + dataclass, + field, + fields, # Import the fields function to inspect dataclass fields +) +from typing import Any + + +# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary +@dataclass +class BaseConfig(collections.abc.Mapping): + """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config. + + The BaseConfig class implements the Mapping Abstract Base Class. + This allows instances of this class to be used like dictionaries. + """ + + extra: dict[str, Any] = field(default_factory=dict) + + def __setattr__(self, name: str, value): + # if the field already exists (i.e. was set in __init__) + # and is in our frozen list, block assignment + if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__: + from dataclasses import FrozenInstanceError + + raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified") + # otherwise do the normal thing + super().__setattr__(name, value) + + def get(self, key: str, default: Any = None) -> Any: + """Get the value associated with the given key. If the key does not exist, return the default value. + + Args: + key (str): The attribute name to retrieve. + default (Any, optional): The value to return if the attribute does not exist. Defaults to None. + + Returns: + Any: The value of the attribute or the default value. + """ + try: + return getattr(self, key) + except AttributeError: + return default + + def __getitem__(self, key: str): + """Implement the [] operator for the class. Allows accessing attributes like dictionary items. + + Args: + key (str): The attribute name to retrieve. + + Returns: + Any: The value of the attribute. + + Raises: + AttributeError: If the attribute does not exist. + TypeError: If the key type is not string + """ + return getattr(self, key) + + def __iter__(self): + """Implement the iterator protocol. Allows iterating over the attribute names of the instance. + + Yields: + str: The name of each field in the dataclass. + """ + for f in fields(self): + yield f.name + + def __len__(self): + """ + Return the number of fields in the dataclass. + + Returns: + int: The number of fields in the dataclass. + """ + return len(fields(self)) diff --git a/verl/experimental/__init__.py b/verl/experimental/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/verl/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/agent_loop/__init__.py b/verl/experimental/agent_loop/__init__.py new file mode 100644 index 000000000..a39171db7 --- /dev/null +++ b/verl/experimental/agent_loop/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .agent_loop import AgentLoopBase, AgentLoopManager +from .single_turn_agent_loop import SingleTurnAgentLoop +from .tool_agent_loop import ToolAgentLoop + +_ = [SingleTurnAgentLoop, ToolAgentLoop] + +__all__ = ["AgentLoopBase", "AgentLoopManager"] diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py new file mode 100644 index 000000000..480f6593d --- /dev/null +++ b/verl/experimental/agent_loop/agent_loop.py @@ -0,0 +1,538 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import heapq +import logging +import os +import random +from abc import ABC, abstractmethod +from typing import Any + +import hydra +import numpy as np +import ray +import torch +from cachetools import LRUCache +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel +from tensordict import TensorDict +from transformers import AutoTokenizer + +from verl.protocol import DataProto +from verl.single_controller.ray.base import RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op +from verl.workers.rollout.async_server import async_server_class + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AsyncLLMServerManager: + """ + A class to manage multiple OpenAI compatible LLM servers. This class provides + - Load balance: least requests load balancing + - Sticky session: send multi-turn chat completions to same server for automatic prefix caching + """ + + def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000): + """Initialize the AsyncLLMServerManager. + + Args: + config (DictConfig): YAML config. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000. + """ + self.config = config + self.server_handles = server_handles + random.shuffle(self.server_handles) + + # Least requests load balancing + self.weighted_serveres = [[0, (hash(server), server)] for server in server_handles] + heapq.heapify(self.weighted_serveres) + + # LRU cache to map request_id to server + self.request_id_to_server = LRUCache(maxsize=max_cache_size) + + def _choose_server(self, request_id: str) -> ray.actor.ActorHandle: + # TODO: implement server pressure awareness load balancing + if request_id in self.request_id_to_server: + return self.request_id_to_server[request_id] + + server = self.weighted_serveres[0][1][1] + self.weighted_serveres[0][0] += 1 + heapq.heapreplace(self.weighted_serveres, self.weighted_serveres[0]) + self.request_id_to_server[request_id] = server + return server + + @rollout_trace_op + async def generate( + self, + request_id, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + ) -> list[int]: + """Generate tokens from prompt ids. + + Args: + request_id (str): request id for sticky session. + prompt_ids (List[int]): List of prompt token ids. + sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. + + Returns: + List[int]: List of generated token ids. + """ + server = self._choose_server(request_id) + output = await server.generate.remote( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + ) + return output + + +class AgentLoopMetrics(BaseModel): + """Agent loop performance metrics.""" + + generate_sequences: float = 0.0 + tool_calls: float = 0.0 + + +class AgentLoopOutput(BaseModel): + """Agent loop output.""" + + prompt_ids: list[int] + response_ids: list[int] + response_mask: list[int] + num_turns: int = 0 + metrics: AgentLoopMetrics + + +# make hydra.utils.instantiate happy +class _DummyConfig: + def __init__(self, config: DictConfig) -> None: + self.config = config + + +class AgentLoopBase(ABC): + """An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various + environments.""" + + _class_initialized = False + + def __init__( + self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs + ): + """Initialize agent loop, each sample will have its own loop instance. + + Args: + trainer_config (_DummyConfig): trainer config. + server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + """ + self.init_class(trainer_config.config, tokenizer, **kwargs) + self.config = trainer_config.config + self.server_manager = server_manager + self.tokenizer = tokenizer + self.loop = asyncio.get_running_loop() + + @classmethod + def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs): + """This is used to do heavy initialization work that should shared across all instances. It's only called once. + + Args: + config (DictConfig): trainer config. + tokenizer (AutoTokenizer): Tokenizer for tokenize messages. + **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`. + """ + if cls._class_initialized: + return + cls._class_initialized = True + + @abstractmethod + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + """Run agent loop to interact with LLM server and environment. + + Args: + messages (List[Dict[str, Any]]): Input messages. + sampling_params (Dict[str, Any]): LLM sampling params. + + Returns: + AgentLoopOutput: Agent loop output. + """ + raise NotImplementedError + + +"""Agent loop registry: key is agent_name, value is a dict of agent loop config +used by hydra.utils.instantiate to initialize agent loop instance. + +https://hydra.cc/docs/advanced/instantiate_objects/overview/ +""" +_agent_loop_registry: dict[str, dict] = {} + + +def register(agent_name: str): + """Register agent loop class.""" + + def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]: + fqdn = f"{subclass.__module__}.{subclass.__qualname__}" + _agent_loop_registry[agent_name] = {"_target_": fqdn} + return subclass + + return decorator + + +@ray.remote +class AgentLoopWorker: + """Agent loop worker takes a batch of messages and run each message in an agent loop.""" + + def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]): + """Initialize agent loop manager. + + Args: + config (DictConfig): YAML config. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + """ + self.config = config + self.server_manager = AsyncLLMServerManager(config, server_handles) + + model_path = config.actor_rollout_ref.model.path + self.model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(config.actor_rollout_ref.model.path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path + if agent_loop_config_path: + agent_loop_configs = OmegaConf.load(agent_loop_config_path) + for agent_loop_config in agent_loop_configs: + _agent_loop_registry[agent_loop_config.name] = agent_loop_config + + trace_config = config.trainer.get("rollout_trace", {}) + trace_config = self.config.actor_rollout_ref.rollout.get("trace", {}) + RolloutTraceConfig.init( + self.config.trainer.project_name, + self.config.trainer.experiment_name, + trace_config.get("backend"), + trace_config.get("token2text", False), + ) + + async def generate_sequences(self, batch: DataProto) -> DataProto: + """Generate sequences from agent loop. + + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch. + - prompts: [bsz, prompt_length], prompt token ids from dataset. + - responses: [bsz, response_length], output token ids include response tokens + from LLM generation and observation tokens from tool_calls. + - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. + - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens + and response tokens. + - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. + - position_ids: [bsz, prompt_length + response_length], incremental position ids. + + For multi-turn conversations: + responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| + response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| + """ + config = self.config.actor_rollout_ref.rollout + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + repetition_penalty=1.0, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["temperature"] = config.val_kwargs.temperature + + # by default, we assume it's a single turn agent + if "agent_name" not in batch.non_tensor_batch: + batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) + + tasks = [] + agent_names = batch.non_tensor_batch["agent_name"] + raw_prompts = batch.non_tensor_batch["raw_prompt"] + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(raw_prompts)) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) + ) + + for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True): + tasks.append( + asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory)) + ) + outputs = await asyncio.gather(*tasks) + + output = self._postprocess(outputs) + return output + + async def _run_agent_loop( + self, + agent_name: str, + messages: list[dict[str, Any]], + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + ) -> AgentLoopOutput: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=_DummyConfig(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + ) + output = await agent_loop.run(messages, sampling_params) + return output + + def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts + self.tokenizer.padding_side = "left" + outputs = self.tokenizer.pad( + [{"input_ids": input.prompt_ids} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # responses + self.tokenizer.padding_side = "right" + outputs = self.tokenizer.pad( + [{"input_ids": input.response_ids} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=True, + ) + response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # response_mask + outputs = self.tokenizer.pad( + [{"input_ids": input.response_mask} for input in inputs], + padding="max_length", + max_length=self.config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=False, + ) + response_mask = outputs["input_ids"] + assert response_ids.shape == response_mask.shape, ( + f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" + ) + response_mask = response_mask * response_attention_mask + + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompt_ids, # [bsz, prompt_length] + "responses": response_ids, # [bsz, response_length] + "response_mask": response_mask, # [bsz, response_length] + "input_ids": input_ids, # [bsz, prompt_length + response_length] + "attention_mask": attention_mask, # [bsz, prompt_length + response_length] + "position_ids": position_ids, # [bsz, prompt_length + response_length] + }, + batch_size=len(input_ids), + ) + + num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32) + metrics = [input.metrics.model_dump() for input in inputs] + return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics}) + + +async def get_trajectory_info(step, index, validate): + """Get trajectory info. + + Args: + step (int): global steps in the trainer. + index (list): form datastore extra_info.index column. + validate (bool): whether is a validate step. + + Returns: + list: trajectory. + """ + trajectory_info = [] + rollout_n = 0 + for i in range(len(index)): + if i > 0 and index[i - 1] == index[i]: + rollout_n += 1 + else: + rollout_n = 0 + trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate}) + return trajectory_info + + +class AgentLoopManager: + """Agent loop manager that manages a group of agent loop workers.""" + + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): + """Initialize agent loop manager. + + Args: + config (DictConfig): trainer config. + worker_group (RayWorkerGroup): ActorRolloutRef worker group. + """ + self.config = config + self.worker_group = worker_group + + self._initialize_llm_servers() + self._init_agent_loop_workers() + + # Initially we're in sleep mode. + self.sleep() + + def _initialize_llm_servers(self): + self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size + + register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + workers_info = ray.get(register_center.get_worker_info.remote()) + assert len(workers_info) == self.worker_group.world_size + + self.async_llm_servers = [None] * self.rollout_dp_size + self.server_addresses = [None] * self.rollout_dp_size + + if self.config.actor_rollout_ref.rollout.agent.custom_async_server: + server_class = async_server_class( + rollout_backend=self.config.actor_rollout_ref.rollout.name, + rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path, + rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name, + ) + else: + server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name) + + # Start all server instances, restart if address already in use. + unready_dp_ranks = set(range(self.rollout_dp_size)) + while len(unready_dp_ranks) > 0: + servers = { + rollout_dp_rank: server_class.options( + # make sure AsyncvLLMServer colocates with its corresponding workers + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_server_{rollout_dp_rank}", + ).remote(self.config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in unready_dp_ranks + } + + for rollout_dp_rank, server in servers.items(): + try: + address = ray.get(server.get_server_address.remote()) + self.server_addresses[rollout_dp_rank] = address + self.async_llm_servers[rollout_dp_rank] = server + unready_dp_ranks.remove(rollout_dp_rank) + except Exception: + ray.kill(server) + print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") + + # All server instances are ready, init AsyncLLM engine. + ray.get([server.init_engine.remote() for server in self.async_llm_servers]) + + def _init_agent_loop_workers(self): + self.agent_loop_workers = [] + for i in range(self.config.actor_rollout_ref.rollout.agent.num_workers): + self.agent_loop_workers.append( + AgentLoopWorker.options( + name=f"agent_loop_worker_{i}", + ).remote(self.config, self.async_llm_servers) + ) + + def generate_sequences(self, prompts: DataProto) -> DataProto: + """Split input batch and dispatch to agent loop workers. + + Args: + prompts (DataProto): Input batch. + + Returns: + DataProto: Output batch. + """ + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.wake_up() + chunkes = prompts.chunk(len(self.agent_loop_workers)) + outputs = ray.get( + [ + worker.generate_sequences.remote(chunk) + for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) + ] + ) + output = DataProto.concat(outputs) + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.sleep() + + # calculate performance metrics + metrics = [output.meta_info["metrics"] for output in outputs] # List[List[Dict[str, str]]] + timing = self._performance_metrics(metrics, output) + + output.meta_info = {"timing": timing} + return output + + def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]: + timing = {} + t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) + t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) + timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min() + timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() + timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() + timing["agent_loop/tool_calls/min"] = t_tool_calls.min() + timing["agent_loop/tool_calls/max"] = t_tool_calls.max() + timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean() + + # batch sequence generation is bounded by the slowest sample + slowest = np.argmax(t_generate_sequences + t_tool_calls) + attention_mask = output.batch["attention_mask"][slowest] + prompt_length = output.batch["prompts"].shape[1] + timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] + timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] + timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() + timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() + + return timing + + def wake_up(self): + """Wake up all rollout server instances.""" + ray.get([server.wake_up.remote() for server in self.async_llm_servers]) + + def sleep(self): + """Sleep all rollout server instances.""" + ray.get([server.sleep.remote() for server in self.async_llm_servers]) diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py new file mode 100644 index 000000000..411388e73 --- /dev/null +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -0,0 +1,55 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register +from verl.utils.profiler import simple_timer + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("single_turn_agent") +class SingleTurnAgentLoop(AgentLoopBase): + """Naive agent loop that only do single turn chat completion.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length + self.response_length = self.config.actor_rollout_ref.rollout.response_length + + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + metrics = {} + request_id = uuid4().hex + prompt_ids = await self.loop.run_in_executor( + None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) + ) + + with simple_timer("generate_sequences", metrics): + response_ids = await self.server_manager.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + response_mask = [1] * len(response_ids) + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + num_turns=2, + metrics=metrics, + ) + return output diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py new file mode 100644 index 000000000..3437c0be5 --- /dev/null +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -0,0 +1,166 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +import logging +import os +from typing import Any +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register +from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser +from verl.tools.utils.tool_registry import initialize_tools_from_config +from verl.utils.profiler import simple_timer +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("tool_agent") +class ToolAgentLoop(AgentLoopBase): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + if cls._class_initialized: + return + cls._class_initialized = True + print("Performing class-level ToolAgentLoop initialization") + + # Initialize tools from config file + cls.tokenizer = tokenizer + cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns + cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls + cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length + cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side + tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path + tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + cls.tools = {tool.name: tool for tool in tool_list} + cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer) + print(f"Initialized tools: {cls.tools}") + + cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length + cls.response_length = config.actor_rollout_ref.rollout.response_length + cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) + + @rollout_trace_op + async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + metrics = {} + request_id = uuid4().hex + prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True + ), + ) + response_mask = [] + + user_turns, assistant_turns = 0, 0 + while True: + with simple_timer("generate_sequences", metrics): + response_ids = await self.server_manager.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + prompt_ids += response_ids + response_mask += [1] * len(response_ids) + assistant_turns += 1 + + # reach max response length + if len(response_mask) >= self.response_length: + break + + # reach max assistant turns + if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns: + break + + # reach max user turns + if self.max_user_turns and user_turns >= self.max_user_turns: + break + + # no tool calls + _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids) + if not tool_calls: + break + + # call tools + tasks = [] + for tool_call in tool_calls[: self.max_parallel_calls]: + tasks.append(self._call_tool(tool_call)) + with simple_timer("tool_calls", metrics): + tool_responses = await asyncio.gather(*tasks) + if any(isinstance(item, Exception) for item in tool_responses): + break + + # append tool_response_ids + tool_response_ids = await self.loop.run_in_executor( + None, + lambda messages=tool_responses: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ), + ) + tool_response_ids = tool_response_ids[len(self.system_prompt) :] + + # NOTE: last turn should not be user turn, or the EOS token reward + # can't be propagated to previous token in GAE. + if len(response_mask) + len(tool_response_ids) >= self.response_length: + break + + prompt_ids += tool_response_ids + response_mask += [0] * len(tool_response_ids) + user_turns += 1 + + response_ids = prompt_ids[-len(response_mask) :] + prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + num_turns=user_turns + assistant_turns + 1, + metrics=metrics, + ) + return output + + async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: + """Call tool and return tool response.""" + tool, instance_id = None, None + try: + # TODO: append malformed tool_call to the prompt: invalid function name or arguments + tool_name = tool_call.name + tool_args = json.loads(tool_call.arguments) + tool = self.tools[tool_name] + + instance_id = await tool.create() + tool_response, _, _ = await tool.execute(instance_id, tool_args) + except Exception as e: + logger.exception(f"Error when executing tool: {e}") + return e + finally: + if tool and instance_id: + await tool.release(instance_id) + + if len(tool_response) > self.max_tool_response_length: + if self.tool_response_truncate_side == "left": + tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" + elif self.tool_response_truncate_side == "right": + tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] + else: + length = self.max_tool_response_length // 2 + tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] + + return { + "role": "tool", + "content": tool_response, + } diff --git a/verl/experimental/agent_loop/tool_parser.py b/verl/experimental/agent_loop/tool_parser.py new file mode 100644 index 000000000..5b4de4a8e --- /dev/null +++ b/verl/experimental/agent_loop/tool_parser.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +import logging +import os +from abc import ABC, abstractmethod + +import regex as re +from pydantic import BaseModel + +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class FunctionCall(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: str + """The name of the function to call.""" + + +class ToolParser(ABC): + _registry: dict[str, type["ToolParser"]] = {} + + def __init__(self, tokenizer) -> None: + self.tokenizer = tokenizer + + @abstractmethod + async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: + """Extract tool calls from the responses. + + Args: + responses_ids (List[int]): The ids of the responses. + + Returns: + Tuple[str, List[FunctionCall]]: Content and extracted tool calls. + """ + raise NotImplementedError + + @classmethod + def get_tool_parser(cls, name: str, tokenizer): + if name not in cls._registry: + raise ValueError(f"Unknown tool parser: {name}") + return cls._registry[name](tokenizer) + + @classmethod + def register(cls, name: str): + def decorator(subclass: type[ToolParser]) -> type[ToolParser]: + cls._registry[name] = subclass + return subclass + + return decorator + + +@ToolParser.register("hermes") +class HermesToolParser(ToolParser): + """Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py""" + + def __init__(self, tokenizer) -> None: + super().__init__(tokenizer) + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) + + @rollout_trace_op + async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: + loop = asyncio.get_running_loop() + text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids) + if self.tool_call_start_token not in text or self.tool_call_end_token not in text: + return text, [] + + matches = self.tool_call_regex.findall(text) + function_calls = [] + for match in matches: + try: + function_call = json.loads(match) + name, arguments = function_call["name"], function_call["arguments"] + function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))) + except Exception as e: + logger.error(f"Failed to decode tool call: {e}") + + # remaing text exclude tool call tokens + content = self.tool_call_regex.sub("", text) + + return content, function_calls diff --git a/verl/experimental/dataset/__init__.py b/verl/experimental/dataset/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/verl/experimental/dataset/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/dataset/sampler.py b/verl/experimental/dataset/sampler.py new file mode 100644 index 000000000..b7b15b422 --- /dev/null +++ b/verl/experimental/dataset/sampler.py @@ -0,0 +1,40 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from collections.abc import Sized + +from omegaconf import DictConfig +from torch.utils.data import Sampler + +from verl import DataProto + + +class AbstractSampler(Sampler[int]): + """Abstract interface for custom samplers.""" + + @abstractmethod + def __init__( + self, + data_source: Sized, + data_config: DictConfig, + ): + pass + + +class AbstractCurriculumSampler(AbstractSampler): + """Experimental interface for curriculum learning samplers.""" + + @abstractmethod + def update(self, batch: DataProto) -> None: + pass diff --git a/verl/experimental/dynamic_dataset/__init__.py b/verl/experimental/dynamic_dataset/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/verl/experimental/dynamic_dataset/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/dynamic_dataset/dynamicgen_dataset.py b/verl/experimental/dynamic_dataset/dynamicgen_dataset.py new file mode 100644 index 000000000..a9532aa03 --- /dev/null +++ b/verl/experimental/dynamic_dataset/dynamicgen_dataset.py @@ -0,0 +1,112 @@ +# Copyright 2025 Amazon.com Inc and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Dataset class that enables dynamic data generation strategies between iterations of training. +This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. + +This is especially useful in settings where proposer model generates new tasks based +on rollout data. +""" + +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import datasets +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin + +from verl import DataProto +from verl.utils.dataset import RLHFDataset +from verl.utils.import_utils import load_extern_type + +logger = logging.getLogger(__name__) + + +class AbstractDataGenerator(ABC): + def __init__(self, config: DictConfig): + self.config = config + + @abstractmethod + def generate(self, dataset: Dataset) -> datasets.Dataset: + """ + Generate method must be implemented by subclasses. + Args: + dataset: The dataset to generate from. + Returns: + Processed data or result as implemented by the subclass. + """ + pass + + +class MockDataGenerator(AbstractDataGenerator): + """ + A noop data gen class that only reappends the first datapoint. + This class is useful as a placeholder and testing. + """ + + def __init__(self, config: DictConfig = None): + super().__init__(config) + + def generate(self, dataset: Dataset) -> datasets.Dataset: + print("MockDataGenerator: No operation performed on the dataset.") + return dataset.dataframe.select([0]) + + +class DynamicGenDataset(RLHFDataset): + """ + A dataset class that uses a data generation strategy to process data. + This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + super().__init__(data_files, tokenizer, config, processor) + self.datagen: AbstractDataGenerator = config.datagen + assert "datagen" in config and config.datagen.get("path", None) is not None, ( + f"datagen path is not set in config: {config}" + ) + # Dynamically load the custom datagen class + datagen_cls = load_extern_type(config.datagen.path, config.datagen.name) + + # Verify that the custom datagen class inherits from AbstractDataGenerator + abs_cls = AbstractDataGenerator + if not issubclass(datagen_cls, abs_cls): + raise TypeError( + f"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'" + + " must inherit from {abs_cls}" + ) + + self.data_generator = datagen_cls(config.datagen) + self.on_batch_end() + + def append_dataframe(self, new_dataframe: datasets.Dataset): + new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe) + self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe]) + + logger.info(f"new dataset len: {len(self.dataframe)}") + + def on_batch_end(self, batch: DataProto) -> None: + """ + Generate data using the provided data generation strategy. + Note: This method is intended to change the dataset after each training batch. + """ + new_data = self.data_generator.generate(self) + self.append_dataframe(new_data) diff --git a/verl/interactions/__init__.py b/verl/interactions/__init__.py new file mode 100644 index 000000000..b6db0fcef --- /dev/null +++ b/verl/interactions/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/interactions/base.py b/verl/interactions/base.py new file mode 100644 index 000000000..7c5d200ab --- /dev/null +++ b/verl/interactions/base.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional +from uuid import uuid4 + + +class BaseInteraction: + def __init__(self, config: dict[str, Any]): + self.config = config + self.name: str = config.get("name", "interaction_agent") # More general agent default role name + + async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method + """ + Generates a response for the current turn of interaction. + Returns a tuple containing: + - should_terminate_sequence (bool): True if the interaction sequence should end. + - response_content (str): The textual content of the response. + - current_turn_score (float): The score for this specific turn/response. + - additional_data (dict): Any extra information or metadata. + """ + should_terminate_sequence: bool = False # if True, end rollout + response_content: str = "Your current result seems acceptable." + current_turn_score: float = 0.8 + additional_data: dict[str, Any] = {} + return should_terminate_sequence, response_content, current_turn_score, additional_data + + async def calculate_score(self) -> float: # More clear score calculation method + """ + Calculates a score for the interaction, + potentially considering aspects like partial exposure & in-context task switching. + should be invoke at turn-level + """ + # ...implement the logic to calculate turn-level score... + score = 0.0 + return score + + async def finalize_interaction(self) -> None: # More clear interaction end and resource release method + """ + Finalizes the interaction session and releases any associated state or resources. + Simulates: release state + """ + # ...implement the logic to release state... + pass diff --git a/verl/interactions/gsm8k_interaction.py b/verl/interactions/gsm8k_interaction.py new file mode 100644 index 000000000..365cbb935 --- /dev/null +++ b/verl/interactions/gsm8k_interaction.py @@ -0,0 +1,90 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k + +from .base import BaseInteraction + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kInteraction(BaseInteraction): + """A demo interaction for calculating the reward of gsm8k. + + - `start_interaction`: start a interaction instance for a trajectory. + - `generate_response`: generate the response of the user. + - `calculate_score`: calculate the score of the interaction. + - `finalize_interaction`: finalize the interaction instance. + """ + + def __init__(self, config: dict): + super().__init__(config) + self._instance_dict = {} + + async def start_interaction( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict]: + content = "" + for i in range(len(messages) - 1, -1, -1): + item = messages[i] + if item.get("role") == "user": + content = item.get("content") + break + + if content and content.startswith("#### "): + self._instance_dict[instance_id]["response"] = content + else: + self._instance_dict[instance_id]["response"] = "#### " + (content or "") + + reward = await self.calculate_score(instance_id) + if reward == 1.0: + response = "Your response is correct!" + should_terminate_sequence = True + else: + response = "Your response is incorrect! You need to reflect on your answer and try again." + should_terminate_sequence = False + + return should_terminate_sequence, response, reward, {} + + async def calculate_score(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", + format_score=0.0, + score=1.0, + ) + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/verl/interactions/utils/__init__.py b/verl/interactions/utils/__init__.py new file mode 100644 index 000000000..c4b932b1a --- /dev/null +++ b/verl/interactions/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/interactions/utils/interaction_registry.py b/verl/interactions/utils/interaction_registry.py new file mode 100644 index 000000000..df747af11 --- /dev/null +++ b/verl/interactions/utils/interaction_registry.py @@ -0,0 +1,85 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import logging +import os +import sys + +from omegaconf import OmegaConf + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def get_interaction_class(cls_name): + """Dynamically import and return the interaction class.""" + module_name, class_name = cls_name.rsplit(".", 1) + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + interaction_cls = getattr(module, class_name) + return interaction_cls + + +def initialize_interactions_from_config(interaction_config_file): + """Initialize interactions from configuration file. + + Args: + interaction_config_file: Path to the interaction configuration file. + + Returns: + dict: A dictionary mapping interaction names to BaseInteraction instances. + """ + interaction_config = OmegaConf.load(interaction_config_file) + interaction_map = {} + + for interaction_item in interaction_config.interaction: + cls_name = interaction_item.class_name + interaction_cls = get_interaction_class(cls_name) + + # Extract config and name + config = OmegaConf.to_container(interaction_item.config, resolve=True) + + # Get the interaction name - either from config or derive from class name + name = interaction_item.get("name", None) + if name is None: + # If no name is specified, use the class name as default + class_simple_name = cls_name.split(".")[-1] + # Remove "Interaction" suffix if present, otherwise use full class name + if class_simple_name.endswith("Interaction"): + name = class_simple_name[:-11].lower() # Remove "Interaction" (11 chars) + else: + name = class_simple_name.lower() + + # Check for duplicate names + if name in interaction_map: + raise ValueError(f"Duplicate interaction name '{name}' found. Each interaction must have a unique name.") + + # Inject the name into the config + config["name"] = name + + # Create the interaction instance + interaction = interaction_cls(config=config) + interaction_map[name] = interaction + + logger.info(f"Initialized interaction '{name}' with class '{cls_name}'") + + return interaction_map diff --git a/verl/model_merger/__init__.py b/verl/model_merger/__init__.py new file mode 100644 index 000000000..1ce90c5eb --- /dev/null +++ b/verl/model_merger/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/model_merger/__main__.py b/verl/model_merger/__main__.py new file mode 100644 index 000000000..f3ab5b9c2 --- /dev/null +++ b/verl/model_merger/__main__.py @@ -0,0 +1,73 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. + +To merge FSDP checkpoints: +```sh +python -m verl.model_merger merge \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +To merge Megatron checkpoints: +```sh +python -m verl.model_merger merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +or use distribtued merge for large models like dpskv3 671B + +```sh +torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\ + --backend megatron \ + --local_dir ./checkpoints/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + + +For more details, please refer to documentation: +https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model +""" + +from .base_model_merger import generate_config_from_args, parse_args + + +def main(): + args = parse_args() + config = generate_config_from_args(args) + print(f"config: {config}") + + if config.backend == "fsdp": + from .fsdp_model_merger import FSDPModelMerger + + merger = FSDPModelMerger(config) + elif config.backend == "megatron": + from .megatron_model_merger import MegatronModelMerger + + merger = MegatronModelMerger(config) + else: + raise NotImplementedError(f"Unknown backend: {config.backend}") + + merger.merge_and_save() + merger.cleanup() + + +if __name__ == "__main__": + main() diff --git a/verl/model_merger/base_model_merger.py b/verl/model_merger/base_model_merger.py new file mode 100644 index 000000000..73ddeb0e1 --- /dev/null +++ b/verl/model_merger/base_model_merger.py @@ -0,0 +1,325 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import init_empty_weights +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + GenerationConfig, +) + +from verl.utils import hf_processor, hf_tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser(description="verl model merger") + subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + + base_op_parser = argparse.ArgumentParser(add_help=False) + base_op_parser.add_argument( + "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + ) + base_op_parser.add_argument("--local_dir", type=str, default=None, help="Path to the saved model checkpoints.") + base_op_parser.add_argument( + "--tie-word-embedding", + action="store_true", + help="Whether to tie word embedding weights (currently only Megatron supported)", + ) + base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code") + base_op_parser.add_argument( + "--is-value-model", + action="store_true", + help="Whether the model is a value model (currently only Megatron supported)", + ) + base_op_parser.add_argument( + "--use_cpu_initialization", + action="store_true", + help="Whether to use CPU initialization for the model. This is useful for large models that cannot " + "fit into GPU memory during initialization.", + ) + + merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser.add_argument( + "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + ) + merge_parser.add_argument( + "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + ) + merge_parser.add_argument( + "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + ) + + test_parser = subparsers.add_parser( + "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + ) + test_parser.add_argument( + "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + ) + + args = parser.parse_args() + return args + + +@dataclass +class ModelMergerConfig: + operation: str # 'merge' or 'test' + backend: str + target_dir: Optional[str] = "tmp" + hf_upload_path: Optional[str] = None + private: bool = False + test_hf_dir: Optional[str] = None + tie_word_embedding: bool = False + trust_remote_code: bool = False + is_value_model: bool = False + local_dir: Optional[str] = None + hf_model_config_path: Optional[str] = None + hf_upload: bool = field(init=False) + use_cpu_initialization: bool = False + + def __post_init__(self): + self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) + if self.operation == "test": + self.target_dir = None + self.hf_upload_path = None + self.private = False + + +def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig: + common_config_args = { + "operation": args.operation, + "backend": args.backend, + "tie_word_embedding": args.tie_word_embedding, + "trust_remote_code": args.trust_remote_code, + "is_value_model": args.is_value_model, + "local_dir": args.local_dir, + "hf_model_config_path": os.path.join(args.local_dir, "huggingface"), + "use_cpu_initialization": args.use_cpu_initialization, + } + + if args.operation == "merge": + config = ModelMergerConfig( + **common_config_args, + target_dir=args.target_dir, + hf_upload_path=args.hf_upload_path, + private=args.private, + test_hf_dir=None, + ) + os.makedirs(config.target_dir, exist_ok=True) + elif args.operation == "test": + config = ModelMergerConfig( + **common_config_args, + test_hf_dir=args.test_hf_dir, + # the following args are not used by test operation + target_dir=None, + hf_upload_path=None, + private=False, + ) + else: + raise NotImplementedError(f"Unknown operation: {args.operation}") + return config + + +class BaseModelMerger(ABC): + """ + Abstract base class for merging distributed model checkpoints into HuggingFace format. + + This class provides common functionality for converting model checkpoints from different + distributed training backends (FSDP, Megatron) into standard HuggingFace format that + can be easily loaded and used for inference or further training. + + The merger supports two main operations: + - merge: Convert and save checkpoints to HuggingFace format + - test: Validate merged checkpoints against a reference model + + Args: + config (ModelMergerConfig): Configuration object containing paths, backend type, + and operation parameters. + + Attributes: + config (ModelMergerConfig): The configuration object passed during initialization. + hf_model_config_path (str): Path to the HuggingFace model configuration files. + model_config (PretrainedConfig): Loaded HuggingFace model configuration. + """ + + def __init__(self, config: ModelMergerConfig): + self.config = config + self.hf_model_config_path = config.hf_model_config_path + self.model_config = AutoConfig.from_pretrained( + self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + + def get_transformers_auto_model_class(self): + if "ForTokenClassification" in self.model_config.architectures[0]: + return AutoModelForTokenClassification + elif "ForCausalLM" in self.model_config.architectures[0]: + return AutoModelForCausalLM + elif "ForConditionalGeneration" in self.model_config.architectures[0]: + return AutoModelForVision2Seq + + raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + + def patch_model_generation_config(self, model): + """ + The generation_config created from model config may be different to the pretrained model, + this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 + + This function patch the generation_config created from model config to the pretrained model. + """ + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + except OSError: + print( + f"Warning: Generation config file not found in {self.hf_model_config_path}, using a " + f"generation config created from the model config." + ) + return model + + def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): + """ + Save lora adapter to safetensors. + + Returns: + lora_path: str, the path to the lora adapter. None if no lora adapter found. + + Note: + This function change the 'state_dict' in place. + """ + lora_params_names = [name for name in state_dict.keys() if "lora_" in name] + + if len(lora_params_names) == 0: + return None + + import json + from typing import OrderedDict + + import peft + from safetensors.torch import save_file + + lora_params = OrderedDict() + target_modules = set() + lora_key = None + + for name in lora_params_names: + lora_key = name.replace(".default.weight", ".weight") + target_modules.add(lora_key.split(".")[-3]) + lora_params[lora_key] = state_dict.pop(name) + + lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) + peft_dict = { + "r": lora_rank, + "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. + "target_modules": list(target_modules), + } + peft_config = peft.LoraConfig(**peft_dict).to_dict() + peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None + peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["target_modules"] = list(peft_config["target_modules"]) + + lora_path = os.path.join(self.config.target_dir, "lora_adapter") + os.makedirs(lora_path, exist_ok=True) + with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) + + for name in list(state_dict.keys()): + key = ( + name.replace("base_model.model.", "") + .replace(".base_layer.weight", ".weight") + .replace(".base_layer.bias", ".bias") + ) + state_dict[key] = state_dict.pop(name) + + return lora_path + + def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + with init_empty_weights(): + model = auto_model_class.from_config( + self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code + ) + model.to_empty(device="cpu") + model = self.patch_model_generation_config(model) + + lora_path = self.save_lora_adapter(state_dict) + if lora_path: + print(f"Saving lora adapter to {lora_path}") + + print(f"Saving model to {self.config.target_dir}") + model.save_pretrained(self.config.target_dir, state_dict=state_dict) + del state_dict + del model + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def upload_to_huggingface(self): + import requests + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError + + api = HfApi() + try: + # Attempt to create repository + api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + except HfHubHTTPError as e: + # Handle authentication/API errors + if e.response.status_code == 401: + raise PermissionError( + "Hugging Face authentication failed. Verify your token is valid and has write permissions." + ) from e + elif e.response.status_code == 404: + raise RepositoryNotFoundError(f"Repository path not found: {self.config.hf_upload_path}") from e + else: + raise ConnectionError(f"Failed to create repository ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network connection failed. Check your internet connection.") from e + + try: + # Attempt folder upload + api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + except HfHubHTTPError as e: + if e.response.status_code == 401: + raise PermissionError("Authentication failed during upload. Token may have expired.") from e + else: + raise RuntimeError(f"Upload failed ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network interruption during upload. Try again with stable connection.") from e + except OSError as e: + raise FileNotFoundError(f"Local folder error: {self.config.target_dir} - {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error during upload: {str(e)}") from e + + @abstractmethod + def merge_and_save(self): + raise NotImplementedError("Subclasses should implement this method") + + @abstractmethod + def cleanup(self): + raise NotImplementedError("Subclasses should implement this method to clean up resources if needed") diff --git a/verl/model_merger/fsdp_model_merger.py b/verl/model_merger/fsdp_model_merger.py new file mode 100644 index 000000000..7853b2b79 --- /dev/null +++ b/verl/model_merger/fsdp_model_merger.py @@ -0,0 +1,265 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import numpy as np +import torch +from torch.distributed._tensor import Placement, Shard + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from tqdm import tqdm + +from .base_model_merger import BaseModelMerger + + +class FSDPModelMerger(BaseModelMerger): + """ + Model merger for FSDP (Fully Sharded Data Parallel) checkpoints. + + This class handles the conversion of FSDP distributed checkpoints into HuggingFace format. + FSDP shards model parameters across multiple processes, and this merger reconstructs + the full model by loading and concatenating the sharded parameters from all ranks. + + The merger supports various FSDP configurations including: + - Pure FSDP (single dimension sharding) + - FSDP + DDP (data parallel + fully sharded data parallel) + - DTensor-based sharding with custom device meshes + + Key features: + - Automatic detection of world size from checkpoint filenames + - Support for DTensor and non-DTensor checkpoints + - Parallel loading of checkpoint shards for efficiency + - Validation against reference HuggingFace models + + Example: + To merge FSDP checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="fsdp", + local_dir="path/to/fsdp/checkpoints", + target_dir="path/to/output" + ) + merger = FSDPModelMerger(config) + merger.merge_and_save() + ``` + """ + + def _get_world_size(self) -> int: + """_summary_ + From FSDP json config file, extract the world size. + + Returns: + int: world size + """ + config_path = Path(self.config.local_dir) / "fsdp_config.json" + if not config_path.exists(): + raise FileNotFoundError(f"Config file {config_path} does not exist.") + + with open(config_path) as f: + config = json.load(f) + + # Extract world size from the config + world_size = config.get("world_size", None) + if world_size is None: + raise ValueError("World size not found in the config file.") + + return world_size + + def _load_rank_zero_state_dict(self, world_size: int) -> dict: + return torch.load( + Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", + map_location="cpu", + weights_only=False, + ) + + def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + """ + Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. + If no DTensor is found, infers a simple FSDP mesh based on world_size. + """ + pivot_key = sorted(list(state_dict.keys()))[0] + weight = state_dict[pivot_key] + + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([world_size], dtype=np.int64) + mesh_dim_names = ("fsdp",) + + return mesh, mesh_dim_names + + def _calculate_shard_configuration( + self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] + ) -> tuple[int, tuple[int, ...]]: + """Calculates the total number of shards and the shape of the device mesh.""" + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + + if "tp" in mesh_dim_names: + # TODO: "tp" is not supported yet due to the above assert + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + return total_shards, mesh_shape + + def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + """Merges a list of tensors based on their DTensor placement""" + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + + raise NotImplementedError(f"Unsupported placement: {placement}") + + def _load_and_merge_state_dicts( + self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] + ) -> dict[str, torch.Tensor]: + model_state_dict_lst = [None] * total_shards + + def process_one_shard(rank: int, model_state_dict_lst: list): + model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] + for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + future.result() + + # Merge state dicts from all shards + state_dict = {} + param_placements: dict[str, list] = {} + + for key in set(model_state_dict_lst[0].keys()): + state_dict[key] = [] + for model_state_shard in model_state_dict_lst: + # add tensor shard in order of rank to state_dict[key] + tensor = model_state_shard.pop(key) + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + + placements = tuple(tensor.placements) + # replicated placement at dp dimension can be discarded + if mesh_dim_names[0] in ("dp", "ddp"): + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) + + del model_state_dict_lst + + # Merge tensors + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + if key in param_placements: + # merge shards + placements: tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: + # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = self._merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) + + return state_dict + + def merge_and_save(self): + world_size = self._get_world_size() + rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) + + mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + + total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + print(f"Processing model shards with {total_shards} {mesh_shape} in total") + + merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + + hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_state_dict = hf_model.state_dict() + del hf_model + + hf_model_keys = set(hf_state_dict.keys()) + collected_keys = set(state_dict.keys()) + + missing_keys = hf_model_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + + extra_keys = collected_keys - hf_model_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + + for key in hf_model_keys: + hf_shape = hf_state_dict[key].shape + collected_shape = state_dict[key].shape + assert hf_shape == collected_shape, ( + f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" + ) + + hf_dtype = hf_state_dict[key].dtype + collected_dtype = state_dict[key].dtype + assert hf_dtype == collected_dtype, ( + f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" + ) + + torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + + def cleanup(self): + """Cleanup temporary files if needed.""" + # FSDP merger does not create temporary files, so no cleanup is needed. + pass diff --git a/verl/model_merger/megatron_model_merger.py b/verl/model_merger/megatron_model_merger.py new file mode 100644 index 000000000..5be281681 --- /dev/null +++ b/verl/model_merger/megatron_model_merger.py @@ -0,0 +1,537 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, ContextManager + +import numpy as np +import torch +import torch.distributed as dist +from accelerate import init_empty_weights +from megatron.core import mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from safetensors.torch import load_file +from transformers import ( + AutoConfig, + PretrainedConfig, +) + +from verl.models.mcore import hf_to_mcore_config +from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device +from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing +from verl.utils.megatron_utils import get_model +from verl.utils.tokenizer import hf_processor, hf_tokenizer + +from .base_model_merger import BaseModelMerger, ModelMergerConfig + + +@contextmanager +def noop_context() -> Any: + yield + + +def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]: + """Calculate the pipeline sharding configuration for Megatron-LM. + + Args: + layer_num: Total number of layers in the model. + pp_size: Number of pipeline parallel ranks. + + Returns: + layer number of each pp rank. Make the sharding of the pipeline as uniform as possible. + """ + if layer_num < pp_size: + raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.") + + if pp_size < 1: + raise ValueError(f"pp_size must be at least 1, got {pp_size}.") + if pp_size == 1: + return [layer_num] + + if pp_size == 2: + return [ + layer_num // 2, + layer_num - layer_num // 2, + ] + + middle_size = pp_size - 2 + shards_strategy = [] + for middle_layer_num in range(layer_num): + first_last_layer_num = layer_num - middle_layer_num * middle_size + first_layer_num = first_last_layer_num // 2 + last_layer_num = first_last_layer_num - first_last_layer_num // 2 + if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num: + shards_strategy.append( + ( + [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num], + abs(first_layer_num - middle_layer_num), + ) + ) + + # sort by diff of layer_num, to make it as uniform as possible + res = sorted(shards_strategy, key=lambda x: x[1])[0][0] + assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}" + return res + + +class MegatronModelMerger(BaseModelMerger): + """ + Model merger for Megatron-LM distributed checkpoints. + + This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format. + Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute + large language models across multiple GPUs. This merger reconstructs the full model by + loading distributed checkpoints and applying the necessary transformations. + + Key features: + - Support for tensor parallel, pipeline parallel, and data parallel configurations + - Automatic parameter name mapping from Megatron to HuggingFace conventions + - Handling of QKV and gate-up tensor splitting/merging + - Support for tied word embeddings and value models + - Integration with Megatron's distributed checkpointing system + + The merger handles various model architectures and configurations: + - Standard transformer models (GPT-style) + - Models with tied word embeddings + - Value models for reinforcement learning + - Multi-layer attention (MLA) architectures + - Mixture of Experts (MoE) models + + Args: + config (ModelMergerConfig): Configuration object with Megatron-specific settings + including tie_word_embedding and is_value_model flags. + + Example: + To merge Megatron checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="megatron", + local_dir="path/to/megatron/checkpoints", + target_dir="path/to/output", + tie_word_embedding=True + ) + merger = MegatronModelMerger(config) + merger.merge_and_save() + ``` + """ + + def __init__(self, config: ModelMergerConfig): + super().__init__(config) + # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards + if "WORLD_SIZE" not in os.environ: + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + torch.distributed.init_process_group(get_nccl_backend()) + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + local_rank = os.environ.get("LOCAL_RANK", 0) + get_torch_device().set_device(f"{get_device_name()}:{local_rank}") + + mpu.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=self.world_size, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + model_parallel_cuda_manual_seed(0) + self.hf_config = AutoConfig.from_pretrained( + self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + print(self.hf_config, flush=True) + + self.params_mapping = { + # megatron core gpt model name, huggingface model name + # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the + # longer key within the containing relationship is processed first. + "embedding.word_embeddings": "model.embed_tokens", + # input layer norm for dpskv3 + "input_layernorm.weight": "input_layernorm.weight", + "input_layernorm.bias": "input_layernorm.bias", + # attn + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", + "self_attention.linear_qkv": "self_attn.qkv_proj", + "self_attention.q_layernorm": "self_attn.q_norm", + "self_attention.k_layernorm": "self_attn.k_norm", + "self_attention.linear_proj": "self_attn.o_proj", + # mla + "self_attention.linear_q_proj": "self_attn.q_proj", + "self_attention.linear_q_down_proj": "self_attn.q_a_proj", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + "self_attention.linear_q_up_proj": "self_attn.q_b_proj", + "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", + # mlp + "pre_mlp_layernorm": "post_attention_layernorm", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", + "mlp.linear_fc1": "mlp.gate_up_proj", + "mlp.linear_fc2": "mlp.down_proj", + # moe + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + "mlp.router": "mlp.gate", + "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", + "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", + "linear_fc1": "gate_up_proj", + "linear_fc2": "down_proj", + # output + "final_layernorm": "norm", + "output_layer": "lm_head", + } + + if "Qwen2MoeForCausalLM" in self.hf_config.architectures: + self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj" + self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" + self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight" + + def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]: + """_summary_ + Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory. + + Args: + model_ckpt_path (str): Path to the model checkpoint directory. + + Returns: + State dict containing the model parameters. + """ + + # init hf config + self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size) + print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}") + + tf_config = hf_to_mcore_config( + self.hf_config, + torch.bfloat16, + num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None, + num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None, + ) + tf_config.use_cpu_initialization = self.config.use_cpu_initialization + tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False) + + # init megatron model + def megatron_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=tie_word_embeddings, + value=False, + ) + return parallel_model + + context: Callable[..., ContextManager] = ( + init_empty_weights if self.config.use_cpu_initialization else noop_context + ) + with context(): + whole_model = get_model( + model_provider_func=megatron_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + transformer_config=tf_config, + ) + + if self.config.use_cpu_initialization: + # convert meta device to empty tensor so it can use `copy_` function + whole_model[0].module = whole_model[0].module.to_empty(device="cpu") + + # load state dicts + sharded_state_dict = {} + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + sharded_state_dict[key] = model.sharded_state_dict() + model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path) + model_state_dict_list = [] + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + model_state_dict_list.append(model_state_dict[key]) + + return model_state_dict_list + + def _check_megatron_state_key(self, key: str) -> bool: + """ + Checks if the key is a valid Megatron state key. + + Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. + Shall not use key starts with "model." + """ + if key.startswith("model."): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with " + f"'decoder/embedding/output_layer' in TransformerLayer." + ) + + skip_checking_keys = ["embedding.word_embeddings", "output_layer"] + for skip_key in skip_checking_keys: + if skip_key in key: + print(f"skip checking key {key}") + return + + # Exclude extra state keys + if not key.startswith("decoder"): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." + ) + + def _split_tensors( + self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False + ) -> list[torch.Tensor]: + """ + Splits a tensor into multiple tensors based on the name. + This is used to handle qkv and gate_up tensors. + """ + if "linear_fc1.weight" in key: + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + gate, up = tensor.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + return [gate, up] + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: + # if the tensor is qkv, for each param on tp, split into q, k, v + # concat q, k, v separately. + q_lst, k_lst, v_lst = [], [], [] + assert config.num_attention_heads % config.num_key_value_heads == 0 + num_q_per_kv = config.num_attention_heads // config.num_key_value_heads + assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( + f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" + ) + kv_size = tensor.shape[0] // (num_q_per_kv + 2) + split_size = [kv_size * num_q_per_kv, kv_size, kv_size] + + num_query_groups_per_partition = config.num_key_value_heads + for chunk in tensor.chunk(num_query_groups_per_partition): + split_size = [ + kv_size * num_q_per_kv // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + ] + q, k, v = chunk.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) + + return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] + else: + return [tensor] + + def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + state_dict = {} + layers_cum = 0 + if self.world_size > 1: + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + + print(f"{layers_cum=}") + for model_state_dict in model_state_dict_list: + layers_handled = 0 + keys = model_state_dict.keys() + for key in keys: + if "extra_state" in key: + continue + if self.config.tie_word_embedding and ("output_layer" in key): + print("skip lm_head and reward_head loading because of tie_word_embeddings") + continue + + self._check_megatron_state_key(key) + hf_name = self._replace_name(key, self.params_mapping) + assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + if "model.layers." in hf_name: + local_layer_no = int(hf_name.split(".")[2]) + layers_handled = max(local_layer_no, layers_handled) + global_layer_no = local_layer_no + layers_cum + new_key_list = hf_name.split(".") + new_key_list[2] = str(global_layer_no) + hf_name = ".".join(new_key_list) + else: + warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) + + if "mlp.experts." in hf_name and ".weight" in hf_name: + name_prefix, expert_id = hf_name.split(".weight") + for proj in ["gate_up", "down"]: + if f"{proj}_proj" in hf_name: + hf_name = hf_name.replace( + f"mlp.experts.{proj}_proj.weight{expert_id}", + f"mlp.experts.{expert_id}.{proj}_proj.weight", + ) + + tensor = model_state_dict[key] + split_tensor = self._split_tensors( + key, tensor, self.hf_config, is_value_model=self.config.is_value_model + ) + + if len(split_tensor) == 1: + state_dict[hf_name] = split_tensor[0] + elif len(split_tensor) == 3: + # split qkv + for n, d in zip(["q", "k", "v"], split_tensor, strict=True): + state_dict[hf_name.replace("qkv", n)] = d + elif len(split_tensor) == 2: + # split gate up + state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] + state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] + shape_info = ( + split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] + ) + print(f"converted {key} to {hf_name} with shape {shape_info}") + + layers_cum += layers_handled + 1 # zero based + + return state_dict + + def save_hf_model_and_tokenizer(self, merged_state_dict): + if self.world_size == 1: + return super().save_hf_model_and_tokenizer(merged_state_dict) + + from safetensors.torch import save_file + + layer_num = self.hf_config.num_hidden_layers + + # FIXME: make configurable + saves_per_layer = 1 if layer_num < 30 else 2 + saves_total = saves_per_layer * layer_num + saves_indexes = {} + + # calculate the layer start index and key chunks + layer_this_rank = self.pipeline_shards[self.rank] + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + keys = list(merged_state_dict.keys()) + keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer) + numel = 0 + + assert len(keys_chunk) == layer_this_rank * saves_per_layer, ( + f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}." + ) + + # save to model shards manually + target_dir = Path(self.config.target_dir) + for i, keys in enumerate(keys_chunk): + sd_to_save = {k: merged_state_dict[k] for k in keys} + numel += sum([sd_to_save[i].numel() for i in sd_to_save]) + save_idx = layer_start * saves_per_layer + i + save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors" + + save_file(sd_to_save, save_path) + for k in keys: + saves_indexes[k] = str(save_path.name) + + tensor = torch.tensor([numel]).to(get_device_name()) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + numel = tensor.cpu().item() + + all_save_indexes = [{} for _ in range(self.world_size)] + dist.all_gather_object(all_save_indexes, saves_indexes) + saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()} + if self.rank == 0: + with open(target_dir / "model.safetensors.index.json", "w") as f: + json.dump( + { + "metadata": { + "total_size": numel, + }, + "weight_map": saves_indexes, + }, + f, + indent=4, + ) + print(f"model saved to {target_dir} with {numel=}") + + self.model_config.save_pretrained(self.config.target_dir) + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def merge_and_save(self): + from verl.utils.megatron_utils import get_dist_checkpoint_path + + model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) + + model_state_dict = self._load_state_dicts(model_ckpt_path) + merged_state_dict = self._merge_state_dicts(model_state_dict) + del model_state_dict + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + """ + Compares the merged Megatron state_dict against a reference safetensors model. + Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. + """ + ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") + + for name, loaded_weight in state_dict.items(): + # name = self._replace_name(original_name, self.params_mapping) + if not name or name.endswith(".bias") and name not in ref_state_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + if "lm_head.weight" in name: + if self.config.is_value_model or self.config.tie_word_embedding: + continue + if name not in ref_state_dict: + raise RuntimeError(f"key: {name} not exist in state_dict") + param = ref_state_dict[name] + assert loaded_weight.dtype == param.dtype + torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) + + def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: + for m_name, v_name in name_mapping.items(): + if m_name not in megatron_name: + continue + + megatron_name = megatron_name.replace("decoder", "model") + param_name = megatron_name.replace(m_name, v_name) + + return param_name + + return None # Return None if no mapping found + + def cleanup(self): + torch.distributed.destroy_process_group() diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py index 912de7822..dafecfdf0 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -17,6 +17,8 @@ import torch import torch.distributed as dist +from verl.utils.device import get_device_id, get_torch_device + def _megatron_calc_layer_map(config): """Calculate the mapping of global layer_idx to local layer_idx @@ -38,7 +40,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -48,14 +52,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -64,7 +71,9 @@ def _get_gpt_model(model): def fetch_params(module): for param in module.parameters(): - torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -77,12 +86,15 @@ def fetch_params(module): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -138,12 +150,16 @@ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: if gate_name in state_dict and up_name in state_dict: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: @@ -168,7 +184,9 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] @@ -179,7 +197,9 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head @@ -215,7 +235,9 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -291,5 +313,5 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") dist.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py index 18a1cf9ce..2f65bc6b1 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -17,6 +17,8 @@ import torch import torch.distributed as dist +from verl.utils.device import get_device_id, get_torch_device + def _megatron_calc_layer_map(config): """Calculate the mapping of global layer_idx to local layer_idx @@ -38,7 +40,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -48,14 +52,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): +def load_state_dict_to_megatron_llama( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -64,7 +71,9 @@ def _get_gpt_model(model): def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -77,12 +86,15 @@ def broadcast_params(module): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -118,7 +130,7 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor = torch.empty( tensor_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) if torch.distributed.get_rank() == 0: @@ -157,12 +169,14 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -202,12 +216,14 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -226,12 +242,16 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens if torch.distributed.get_rank() == 0: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -250,12 +270,15 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -283,25 +306,33 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -320,12 +351,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -421,5 +454,5 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens for wrapped_model in wrapped_models: broadcast_params(wrapped_model) - torch.cuda.empty_cache() + get_torch_device().empty_cache() print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py index f5ffb4a8f..595efcde3 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -21,7 +21,9 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): @@ -30,7 +32,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -53,7 +57,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -94,7 +100,7 @@ def _get_gpt_model(model): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size @@ -105,7 +111,11 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) state_dict = dict() @@ -146,7 +156,7 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: weight = torch.empty( tensor_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -175,7 +185,7 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -215,7 +225,7 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -264,7 +274,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -316,7 +326,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): state_dict[v_name] = torch.cat(v_weight_list, dim=0) # empty cache before collecting weights - torch.cuda.empty_cache() + get_torch_device().empty_cache() # Embeddings # ------------------- if dp_rank == 0: @@ -403,7 +413,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): src_pp_rank=pp_size - 1, ) _broadcast_tensor( - gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None, + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, "reward_head.weight", src_pp_rank=pp_size - 1, ) @@ -417,7 +429,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): dist.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() if torch.distributed.get_rank() == 0: if dtype not in [torch.float16, torch.bfloat16, torch.float32]: print(f'Unknown/unsupported dtype to save: {dtype}"') diff --git a/verl/models/llama/megatron/layers/__init__.py b/verl/models/llama/megatron/layers/__init__.py index c0e396f49..352bc5608 100644 --- a/verl/models/llama/megatron/layers/__init__.py +++ b/verl/models/llama/megatron/layers/__init__.py @@ -22,4 +22,13 @@ from .parallel_mlp import ParallelLlamaMLP from .parallel_rmsnorm import ParallelLlamaRMSNorm -__all__ = ["LinearForLastLayer", "MergedColumnParallelLinear", "QKVParallelLinear", "ParallelLlamaAttention", "ParallelLlamaDecoderLayer", "ParallelLlamaDecoderLayerRmPad", "ParallelLlamaMLP", "ParallelLlamaRMSNorm"] +__all__ = [ + "LinearForLastLayer", + "MergedColumnParallelLinear", + "QKVParallelLinear", + "ParallelLlamaAttention", + "ParallelLlamaDecoderLayer", + "ParallelLlamaDecoderLayerRmPad", + "ParallelLlamaMLP", + "ParallelLlamaRMSNorm", +] diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py index 5d909dbe9..e8aacbdb7 100644 --- a/verl/models/llama/megatron/layers/parallel_attention.py +++ b/verl/models/llama/megatron/layers/parallel_attention.py @@ -19,7 +19,7 @@ # limitations under the License. import math -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -46,7 +46,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -99,7 +101,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -119,7 +123,9 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device self.factor = config.rope_scaling["factor"] # `8` in the original implementation self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation - self.old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + self.old_context_len = config.rope_scaling[ + "original_max_position_embeddings" + ] # `8192` in the original implementation low_freq_wavelen = self.old_context_len / self.low_freq_factor high_freq_wavelen = self.old_context_len / self.high_freq_factor @@ -128,7 +134,9 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor - self.low_freq_factor) + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) @@ -136,7 +144,9 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def rotate_half(x): @@ -183,15 +193,23 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) self.num_heads_per_tp = self.num_heads // tp_size self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size self.hidden_size_per_tp = self.hidden_size // tp_size if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() @@ -272,7 +290,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states)[0] query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) @@ -291,11 +309,16 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}") + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 @@ -303,7 +326,10 @@ def forward( attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}") + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) @@ -342,8 +368,12 @@ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_l # use flash-attn rotary embeddings with rmpad # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) return q_embed, k_embed @@ -363,7 +393,9 @@ def forward( total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) # (total_nnz, 1, hidden_size) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) if self.megatron_config.sequence_parallel: sequence_parallel_pad = total_nnz - cu_seqlens[-1] @@ -381,8 +413,11 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, # TODO: llama does not have dropout in the config?? # It is recommended to use dropout with FA according to the docs diff --git a/verl/models/llama/megatron/layers/parallel_decoder.py b/verl/models/llama/megatron/layers/parallel_decoder.py index 3f74b69c0..f46e9457c 100644 --- a/verl/models/llama/megatron/layers/parallel_decoder.py +++ b/verl/models/llama/megatron/layers/parallel_decoder.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from megatron.core import ModelParallelConfig @@ -49,7 +49,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -119,7 +119,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # (total_nnz // sp, 1, hidden_size) hidden_states = self.input_layernorm(hidden_states) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index f60112a7e..ed5022e0c 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -19,7 +19,7 @@ # limitations under the License. """PyTorch LLaMA model with Megatron-style acceleration.""" -from typing import Optional, Tuple, Union +from typing import Optional import torch import torch.utils.checkpoint @@ -88,9 +88,13 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): if megatron_config is not None: assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) - self.layers = nn.ModuleList([ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask @@ -107,8 +111,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -117,7 +125,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -176,7 +184,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -230,9 +238,13 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): if megatron_config is not None: assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) - self.layers = nn.ModuleList([ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) def forward( @@ -243,7 +255,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -313,7 +325,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -326,7 +338,9 @@ def forward( batch_size, sequence_length = input_ids.shape # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -355,7 +369,9 @@ def forward( logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # add removed padding back - logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -388,7 +404,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: output = super().forward(input_ids, attention_mask, position_ids) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -422,7 +438,9 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) else: self.embed_tokens = None @@ -469,7 +487,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -524,8 +542,12 @@ def __init__( super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) - assert share_embeddings_and_output_weights is False, "Llama Model not supports sharing embedding and output weights" + self.model = ParallelLlamaModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) + assert share_embeddings_and_output_weights is False, ( + "Llama Model not supports sharing embedding and output weights" + ) self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process @@ -573,7 +595,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -588,7 +610,9 @@ def forward( # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model batch_size, sequence_length = input_ids.shape # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -617,7 +641,9 @@ def forward( totol_nnz = cu_seqlens[-1] logits = logits[:totol_nnz] # (total_nnz_padded) # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -653,7 +679,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py index a4b47418e..29d053177 100644 --- a/verl/models/mcore/__init__.py +++ b/verl/models/mcore/__init__.py @@ -13,6 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .registry import get_mcore_forward_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model +from .registry import ( + get_mcore_forward_fn, + get_mcore_forward_fused_fn, + get_mcore_weight_converter, + hf_to_mcore_config, + init_mcore_model, +) -__all__ = ["hf_to_mcore_config", "init_mcore_model", "get_mcore_forward_fn", "get_mcore_weight_converter"] +__all__ = [ + "hf_to_mcore_config", + "init_mcore_model", + "get_mcore_forward_fn", + "get_mcore_weight_converter", + "get_mcore_forward_fused_fn", +] diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py index 2e21f3ed5..597afcdd1 100644 --- a/verl/models/mcore/config_converter.py +++ b/verl/models/mcore/config_converter.py @@ -19,11 +19,14 @@ import torch import torch.nn.functional as F +from megatron.core import parallel_state as mpu from megatron.core.transformer import MLATransformerConfig, TransformerConfig from transformers import PretrainedConfig -def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> dict: +def _get_base_transformer_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: """ Create a base TransformerConfig with common parameters across different model architectures. TODO: (ycl) use dataclass or converter config? @@ -36,10 +39,12 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype Returns: TransformerConfig with common parameters """ - from megatron.core import parallel_state as mpu # Common parallel state parameters - overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) batch_p2p_comm = False # Base configuration with common parameters @@ -54,6 +59,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), "kv_channels": getattr(hf_config, "head_dim", None), "layernorm_epsilon": hf_config.rms_norm_eps, + "add_bias_linear": True, # Activation and normalization "activation_func": F.silu, "normalization": "RMSNorm", @@ -79,23 +85,75 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype } # Update with any provided overrides + # override_transformer_config_kwargs as kwargs shall never be none base_config.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {base_config}") return base_config -def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: +def _get_mla_transformer_config( + hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs +) -> dict: + """ + Create a MLATransformerConfig with common parameters across different model architectures. + This is specifically for MLA models like DeepseekV3. + + Args: + hf_config: HuggingFace model configuration + mla_rope_config: MLA specific RoPE configuration + dtype: Data type for the model + override_transformer_config_kwargs: Additional parameters to override defaults + + Returns: + MLATransformerConfig with common parameters + """ + base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) + mla_config = { + # MLA specific parameters + "q_lora_rank": hf_config.q_lora_rank, + "kv_lora_rank": hf_config.kv_lora_rank, + "qk_head_dim": hf_config.qk_nope_head_dim, + "qk_pos_emb_head_dim": hf_config.qk_rope_head_dim, + "v_head_dim": hf_config.v_head_dim, + "rotary_base": hf_config.rope_theta, + "rotary_scaling_factor": mla_rope_config["factor"], + "rope_type": mla_rope_config["type"], + "max_position_embeddings": mla_rope_config["original_max_position_embeddings"], + "beta_fast": mla_rope_config["beta_fast"], + "beta_slow": mla_rope_config["beta_slow"], + "mscale": mla_rope_config["mscale"], + "mscale_all_dim": mla_rope_config["mscale_all_dim"], + } + + base_config.update(mla_config) + return base_config + + +def hf_to_mcore_config_dense( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: # for LlamaForCausalLM or Qwen2ForCausalLM qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False - args = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs) + args: dict = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + add_qkv_bias=qkv_bias, + qk_layernorm=qk_layernorm, + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + print(f"Overridden TF init config: {args}") return TransformerConfig(**args) -def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: - args = _get_base_transformer_config( +def hf_to_mcore_config_qwen2moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, @@ -120,13 +178,17 @@ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, # Qwen specific moe_router_pre_softmax=True, add_qkv_bias=True, - **override_transformer_config_kwargs, ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + print(f"Overridden TF init config: {args}") return TransformerConfig(**args) -def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: - args = _get_base_transformer_config( +def hf_to_mcore_config_mixtral( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, @@ -150,13 +212,17 @@ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, apply_rope_fusion=True, bias_activation_fusion=True, bias_dropout_fusion=True, - **override_transformer_config_kwargs, ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + print(f"Overridden TF init config: {args}") return TransformerConfig(**args) -def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: - args = _get_base_transformer_config( +def hf_to_mcore_config_qwen3moe( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + args: dict = _get_base_transformer_config( hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, @@ -179,12 +245,16 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, # Qwen specific moe_router_pre_softmax=False, qk_layernorm=True, - **override_transformer_config_kwargs, ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + print(f"Overridden TF init config: {args}") return TransformerConfig(**args) -def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig: +def hf_to_mcore_config_dpskv3( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> MLATransformerConfig: # DeepseekV3ForCausalLM from megatron.core.transformer.enums import AttnBackend @@ -209,20 +279,23 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, * # disable MTP and quantization for now if "num_nextn_predict_layers" in hf_config: - assert hf_config.num_nextn_predict_layers == 0, "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" - assert "quantization_config" not in hf_config or not hf_config.quantization_config, "quantization is not supported for now, please modify the config.json to remove quantization_config" + assert hf_config.num_nextn_predict_layers == 0, ( + "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" + ) + assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( + "quantization is not supported for now, please modify the config.json to remove quantization_config" + ) - args = _get_base_transformer_config( + args: dict = _get_mla_transformer_config( hf_config=hf_config, + mla_rope_config=mla_rope_config, dtype=dtype, + # Additional parameters use_cpu_initialization=False, add_bias_linear=False, attention_backend=AttnBackend.fused, - bf16=dtype is torch.bfloat16, - layernorm_epsilon=hf_config.rms_norm_eps, - ffn_hidden_size=hf_config.intermediate_size, qk_layernorm=True, - # moe specific + # Standard MoE parameters moe_ffn_hidden_size=hf_config.moe_intermediate_size, moe_token_dispatcher_type="alltoall", moe_router_bias_update_rate=0.001, @@ -239,33 +312,20 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, * moe_router_pre_softmax=True, moe_router_topk_scaling_factor=hf_config.routed_scaling_factor, moe_layer_freq=moe_layer_freq, - # MLA - q_lora_rank=hf_config.q_lora_rank, - kv_lora_rank=hf_config.kv_lora_rank, - qk_head_dim=hf_config.qk_nope_head_dim, - qk_pos_emb_head_dim=hf_config.qk_rope_head_dim, - v_head_dim=hf_config.v_head_dim, - rotary_base=hf_config.rope_theta, - rotary_scaling_factor=mla_rope_config["factor"], - rope_type=mla_rope_config["type"], - mscale=mla_rope_config["mscale"], - mscale_all_dim=mla_rope_config["mscale_all_dim"], - max_position_embeddings=mla_rope_config["original_max_position_embeddings"], - beta_fast=mla_rope_config["beta_fast"], - beta_slow=mla_rope_config["beta_slow"], # mcore 0.12 moe moe_router_dtype="fp64", disable_bf16_reduced_precision_matmul=True, - # other + # Other optimizations # deallocate_pipeline_outputs=True, # gradient_accumulation_fusion=True, persist_layer_norm=True, bias_activation_fusion=True, bias_dropout_fusion=True, ) - if override_transformer_config_kwargs: - args.update(override_transformer_config_kwargs) - transformer_config = MLATransformerConfig(**args) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + transformer_config: MLATransformerConfig = MLATransformerConfig(**args) + print(f"Overridden MLA TF init config: {transformer_config}") # MTP if "num_nextn_predict_layers" in hf_config: transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers @@ -274,11 +334,27 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, * return transformer_config -def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: +def hf_to_mcore_config_qwen2_5_vl( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: # Qwen2_5_VLForConditionalGeneration - raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet") + + args = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + add_bias_linear=False, + # qwen specific + add_qkv_bias=True, + mrope_section=hf_config.rope_scaling["mrope_section"], + ) + # override_transformer_config_kwargs as kwargs shall never be none + args.update(override_transformer_config_kwargs) + print(f"Overridden TF init config: {args}") + return TransformerConfig(**args) -def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: +def hf_to_mcore_config_llama4( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: # Llama4ForConditionalGeneration raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py index 2c8784661..659b4baa2 100644 --- a/verl/models/mcore/loader.py +++ b/verl/models/mcore/loader.py @@ -18,6 +18,8 @@ import torch import torch.distributed as dist +from verl.utils.device import get_device_id, get_torch_device + from .saver import _megatron_calc_global_rank @@ -39,7 +41,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -56,7 +60,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -65,7 +70,9 @@ def _get_gpt_model(model): def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -80,7 +87,7 @@ def broadcast_params(module): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size @@ -121,7 +128,7 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor = torch.empty( tensor_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) if torch.distributed.get_rank() == src_rank: @@ -160,12 +167,14 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -205,12 +214,14 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -229,12 +240,16 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens if torch.distributed.get_rank() == src_rank: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -253,12 +268,15 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -289,7 +307,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sizes = [total_size * tp_size] if not bias: sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] @@ -301,7 +319,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) total_size_per_head = total_size // num_query_groups_per_partition for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) else: q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size @@ -310,7 +330,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sizes = [total_size * tp_size] if not bias: sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head @@ -323,7 +343,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) total_size_per_head = total_size // config.num_attention_heads for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( + torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -342,12 +364,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -371,8 +395,8 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - layer_map = _megatron_calc_layer_map(config) for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") layer_name = f"model.layers.{layer}" + print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) @@ -382,7 +406,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, f"{layer_name}.input_layernorm.weight", ) - + if f"{layer_name}.self_attn.q_norm.weight" in state_dict: _broadcast_tensor( sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, @@ -464,5 +488,5 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - for wrapped_model in wrapped_models: broadcast_params(wrapped_model) pass - torch.cuda.empty_cache() + get_torch_device().empty_cache() print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/mcore/mbridge.py b/verl/models/mcore/mbridge.py new file mode 100644 index 000000000..35c32d697 --- /dev/null +++ b/verl/models/mcore/mbridge.py @@ -0,0 +1,23 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from mbridge import AutoBridge + from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model +except ImportError: + print("mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`") + raise + +__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"] diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index 94b86462b..e70e11f4e 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -29,6 +29,7 @@ def gptmodel_forward( pack_seqs=True, logits_processor=None, logits_processor_args: dict = None, + **kwargs, ): """Default forward pass for GPT models with optional sequence packing.""" pre_process = unwrap_model(model).pre_process @@ -44,22 +45,104 @@ def gptmodel_forward( packed_seq_params=packed_seq_params, ) if post_process and logits_processor is not None: - args = {k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] for k, v in logits_processor_args.items()} + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] + for k, v in logits_processor_args.items() + } output_dict = logits_processor(output_orig, **args) - output = {k: postprocess_packed_seqs(v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) for k, v in output_dict.items()} + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } else: - output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) else: assert logits_processor is None, "logits_processor is not supported for non-packed sequence" batch_size, sequence_length = attention_mask.shape - new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process) + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( + input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + ) output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) - output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process) + output = recover_left_padding( + output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) if value_model and post_process: output = output[..., 0] return output -def gptmodel_forward_qwen2_5_vl(*args, **kwargs): - """Forward pass for Qwen2.5 VL model (not implemented).""" - raise NotImplementedError("VLM is not supported yet") +def gptmodel_forward_qwen2_5_vl( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel, + value_model=False, + pack_seqs=True, + multi_modal_inputs=None, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + from megatron.core import parallel_state as mpu + + assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + pixel_values = ( + multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + ) + image_grid_thw = ( + multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + ) + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + else: + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( + input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + ) + output = model( + input_ids=new_input_ids, + position_ids=new_position_ids, + attention_mask=new_attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + output = recover_left_padding( + output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + if value_model and post_process: + output = output[..., 0] + return output diff --git a/verl/models/mcore/model_forward_fused.py b/verl/models/mcore/model_forward_fused.py new file mode 100644 index 000000000..fc55ef1b0 --- /dev/null +++ b/verl/models/mcore/model_forward_fused.py @@ -0,0 +1,327 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import Optional + +import torch +from megatron.core import parallel_state +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from torch import Tensor + +from verl.models.mcore.util import preprocess_packed_seqs +from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy +from verl.utils.megatron_utils import unwrap_model +from verl.utils.model import CausalLMOutputForPPO + +from .qwen2_5_vl.model import Qwen2_5VLModel +from .util import postprocess_packed_seqs_for_dict_output + + +def patch_fused_forward(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + model = model + elif isinstance(model, Qwen2_5VLModel): + if not hasattr(model, "language_model"): + # the qwen2.5vl model might only have vision_model + return + model = model.language_model + else: + raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") + model.forward_backup = model.forward + model.forward = _fused_GPTModel_forward.__get__(model, model.__class__) + return + + +def unpatch_fused_forward(model: torch.nn.Module): + model = unwrap_model(model) + if isinstance(model, GPTModel): + model = model + elif isinstance(model, Qwen2_5VLModel): + model = model.language_model + else: + raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") + model.forward = model.forward_backup + return + + +def fused_forward_gptmodel( + model: GPTModel, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor, + labels_mask: Tensor, + **kwargs, +): + pre_process: bool = unwrap_model(model).pre_process + post_process: bool = unwrap_model(model).post_process + + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + + output_orig: CausalLMOutputForPPO = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + labels=labels_rmpad, + packed_seq_params=packed_seq_params, + ) + + if post_process: + # output_orig is in type of CausalLMOutputForPPO + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + else: + output = output_orig + return output + + +def fused_forward_qwen2_5_vl( + model: Qwen2_5VLModel, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + labels: Tensor, + labels_mask: Tensor, + multi_modal_inputs=None, + **kwargs, +): + # pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + + pixel_values = ( + multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + ) + image_grid_thw = ( + multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + ) + + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_rmpad = labels_rmpad.contiguous() + labels_mask_rmpad = labels_mask_rmpad.contiguous() + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig: CausalLMOutputForPPO = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + labels=labels, + ) + if post_process: + # output_orig is in type of CausalLMOutputForPPO + output = postprocess_packed_seqs_for_dict_output( + labels_mask_rmpad, + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, + ) + else: + output = output_orig + return output + + +def _fused_GPTModel_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + temperature: float = 1.0, +) -> CausalLMOutputForPPO: + """ + Forward pass for GPT models with fused kernel support. + + Patch https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py + """ + + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + rotary_pos_cos = None + rotary_pos_sin = None + if self.position_embedding_type == "rope" and not self.config.multi_latent_attention: + if not self.training and self.config.flash_decode and inference_context: + assert inference_context.is_static_batching(), "GPTModel currently only supports static inference batching." + # Flash decoding uses precomputed cos and sin for RoPE + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( + inference_context.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), + ) + else: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == "thd", + ) + elif self.position_embedding_type == "mrope" and not self.config.multi_latent_attention: + if self.training or not self.config.flash_decode: + rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) + else: + # Flash decoding uses precomputed cos and sin for RoPE + raise NotImplementedError( + "Flash decoding uses precomputed cos and sin for RoPE, not implmented in MultimodalRotaryEmbedding yet." + ) + + if ( + (self.config.enable_cuda_graph or self.config.flash_decode) + and rotary_pos_cos is not None + and inference_context + and inference_context.is_static_batching() + and not self.training + ): + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * inference_context.current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # skip inference + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + ) + + # Process inference output. + if inference_context and not inference_context.is_static_batching(): + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if self.mtp_process: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + compute_language_model_loss=self.compute_language_model_loss, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + output = CausalLMOutputForPPO( + loss=None, + logits=None, + past_key_values=None, + hidden_states=hidden_states, + attentions=None, + ) + + if self.config.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + logprobs, entropy = linear_cross_entropy( + hidden_states, + self.output_layer.weight, + labels, + temperature, + "none", + parallel_state.get_tensor_model_parallel_group(), + ) + + if has_config_logger_enabled(self.config): + payload = OrderedDict( + { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "decoder_input": decoder_input, + "logprobs": logprobs, + "entropy": entropy, + } + ) + log_config_to_disk(self.config, payload, prefix="input_and_logits") + + output.entropy = entropy + output.log_probs = logprobs + + return output diff --git a/verl/models/mcore/model_initializer.py b/verl/models/mcore/model_initializer.py index 7edbfb482..4c01b124b 100644 --- a/verl/models/mcore/model_initializer.py +++ b/verl/models/mcore/model_initializer.py @@ -85,7 +85,9 @@ def initialize( if post_process and value: from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - model.output_layer = LinearForLastLayer(input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig) + model.output_layer = LinearForLastLayer( + input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig + ) return model @@ -193,4 +195,69 @@ class Qwen25VLModel(BaseModelInitializer): """Initializer for Qwen2.5 VL models.""" def get_transformer_layer_spec(self): - raise NotImplementedError("VLM is not supported yet") + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize( + self, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, + **extra_kwargs, + ): + tfconfig = self.tfconfig + hf_config = self.hf_config + # Qwen2_5_VLForConditionalGeneration + from copy import deepcopy + + transformer_layer_spec = self.get_transformer_layer_spec() + + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear + from megatron.core.models.gpt.moe_module_specs import MLPSubmodules + from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec + + from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config + + vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) + vision_transformer_config.pipeline_model_parallel_size = 1 + vision_transformer_config.first_pipeline_num_layers = None + + vision_projection_config = get_vision_projection_config( + deepcopy(tfconfig), + vision_transformer_config.hidden_size, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + ) + vision_projection_layer_spec = MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + + qwen25_vl_model = Qwen2_5VLModel( + language_transformer_config=tfconfig, + language_transformer_layer_spec=transformer_layer_spec, + language_vocab_size=hf_config.vocab_size, + language_max_sequence_length=hf_config.max_position_embeddings, + vision_transformer_config=vision_transformer_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + language_rotary_base=hf_config.rope_theta, + pre_process=pre_process, + post_process=post_process, + add_decoder=True, + add_encoder=True, + parallel_output=True, + language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + qwen25_vl_model.language_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) + + return qwen25_vl_model diff --git a/verl/models/mcore/patch_v012.py b/verl/models/mcore/patch_v012.py index c1b1a1ea4..d54a3eb34 100644 --- a/verl/models/mcore/patch_v012.py +++ b/verl/models/mcore/patch_v012.py @@ -19,8 +19,15 @@ def apply_patch(): import torch - from megatron.core.transformer.multi_latent_attention import MLASelfAttention, apply_rotary_pos_emb, deprecate_inference_params, gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region from megatron.core import parallel_state, tensor_parallel + from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + apply_rotary_pos_emb, + deprecate_inference_params, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, + ) def patch_get_query_key_value_tensors( self, @@ -44,7 +51,9 @@ def patch_get_query_key_value_tensors( # ========================================= # Prepare RoPE and seqlen related params # ========================================= - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(inference_context, None, hidden_states, self.config, packed_seq_params) + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, None, hidden_states, self.config, packed_seq_params + ) # rotary_pos_emb:[s, b, 1, 64] mscale = 1.0 @@ -87,13 +96,17 @@ def patch_get_query_key_value_tensors( # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] kv_combined = gather_from_tensor_model_parallel_region(kv_combined) # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1) + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) if self.config.sequence_parallel: # kv_compressed:[s / TP, b, kv_lora_rank] kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) else: # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1) + kv_compressed, k_pos_emb = torch.split( + kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + ) if parallel_state.get_tensor_model_parallel_world_size() > 1: # k_pos_emb: [s, b, qk_pos_emb_head_dim] k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) @@ -191,7 +204,9 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po if self.recompute_up_proj: self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() - query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) + query, key, value = self.qkv_up_checkpoint.checkpoint( + qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) else: query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) diff --git a/verl/models/mcore/qwen2_5_vl/__init__.py b/verl/models/mcore/qwen2_5_vl/__init__.py new file mode 100644 index 000000000..8842d0249 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .model import Qwen2_5VLModel +from .vision_config import get_vision_model_config, get_vision_projection_config + +__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/verl/models/mcore/qwen2_5_vl/attention.py b/verl/models/mcore/qwen2_5_vl/attention.py new file mode 100644 index 000000000..91a27cc3e --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/attention.py @@ -0,0 +1,221 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.transformer.attention import * + +from .rope_utils import apply_rotary_pos_emb_absolute + + +class Qwen2_5VLSelfAttention(SelfAttention): + """ + Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute + instead of apply_rotary_pos_emb + """ + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Perform a forward pass through the attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert flash_decode_and_prefill_kernel is not None, ( + "Internal use only: install package `nvidia_chunked_flash_attn`." + ) + + # hidden_states: [sq, b, h] + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. + if ( + self.config.flash_decode + and inference_context is not None + and inference_context.is_decode_only() + and not self.training + and rotary_pos_cos is not None + ): + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill( + q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths + ) + core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias diff --git a/verl/models/mcore/qwen2_5_vl/model.py b/verl/models/mcore/qwen2_5_vl/model.py new file mode 100644 index 000000000..74e4406c3 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/model.py @@ -0,0 +1,340 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import torch +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.models.gpt.gpt_model import GPTModel + +# from .transformer_config import Qwen2VLTransformerConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from .attention import Qwen2_5VLSelfAttention +from .vision_model import Qwen2_5VisionModel + + +# Note: This is under development and may be missing features. +class Qwen2_5VLModel(MegatronModule): + """Qwen2.5VL multi-modal model. + + Args: + language_transformer_config (TransformerConfig): Transformer config for the language model. + language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + language model. + language_vocab_size (int): Language model vocabulary size. + language_max_sequence_length (int): Language model maximum sequence length. This is used for + positional embedding. + vision_transformer_config (TransformerConfig): Transformer config for the vision model. + vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the + vision model. + vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to + language model inputs. + vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision + projection. + vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This + is typically True for training and False for inference. + language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings + in the language model. Defaults to 1.0. + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). + Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline + parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. + When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + img_h (int): The height of each image that the ViT will see. + img_w (int): The width of each image that the ViT will see. + patch_dim (int): The size of each patch side. + img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be + inserted. Defaults to 0. + """ + + def __init__( + self, + language_transformer_config: TransformerConfig, + language_transformer_layer_spec: ModuleSpec, + language_vocab_size: int, + language_max_sequence_length: int, + vision_transformer_config: TransformerConfig, + vision_transformer_layer_spec: ModuleSpec, + vision_projection_config: TransformerConfig, + vision_projection_layer_spec: ModuleSpec, + vision_projection_type: str = "mlp", + parallel_output: bool = True, + language_rotary_percent: float = 1.0, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + language_rotary_base: int = 10000, + fp16_lm_cross_entropy: bool = False, + language_share_embeddings_and_output_weights: bool = False, + image_token_id: int = 151655, + video_token_id: int = 151656, + ) -> None: + super().__init__(config=language_transformer_config) + + # patch self_attention to use qwen2_5_vl attention + vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + for layer_spec in language_transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + + logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.vision_projection = None + self.language_model = None + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + if self.pre_process: + self.vision_model = Qwen2_5VisionModel( + vision_transformer_config, + vision_transformer_layer_spec, + vision_projection_config, + vision_projection_layer_spec, + projection_type=vision_projection_type, + pre_process=True, + post_process=True, + ) + + self.language_model = GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_rotary_base, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + ) + + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. + """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_projection is not None: + modules.append(self.vision_projection) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + ) -> torch.Tensor: + """Forward function of the Qwen2VL model. + + Args: + image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, + combined_seq_len]. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + + video_start_index: + 0 -- all video + len(video_seq) -- all image + others -- mixture + *_input_mask: should not be None in the first PP stage + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape + [b, s, vocab_size]. + """ + video_start_index = 0 + vision_grid_thw = None + vision_data = None + if image_grid_thw is not None: + image_mask = input_ids == self.image_token_id + vision_grid_thw = image_grid_thw + vision_data = pixel_values + video_start_index = image_mask.sum().item() + if video_grid_thw is not None: + video_mask = input_ids == self.video_token_id + vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) + vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) + video_start_index = image_mask.sum().item() + video_mask.sum().item() + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + if use_inference_kv_cache: + raise NotImplementedError() + + if self.pre_process: + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + vision_embeds = self.vision_model( + vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) + grid_thw=vision_grid_thw, # should provided in each EPP stage + ) + + # If running inference, the language model KV cache will be updated for image token positions. + # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. + if inference_params is not None: + raise NotImplementedError() + # inference_params.key_value_memory_dict["image_tokens_count"] = ( + # vision_embeddings.shape[0] + # ) + + # If running inference, we can skip image token computation if they were computed already earlier + # for this sample. + if use_inference_kv_cache: + language_embeddings: torch.Tensor = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + # NOTE: why not cat here? is it the combined embeddings useless? + combined_embeddings = language_embeddings + elif vision_embeds is not None: + if video_start_index == 0: + image_embeds = None + video_embeds = vision_embeds + elif video_start_index == vision_embeds.shape[0]: + image_embeds = vision_embeds + video_embeds = None + elif 0 < video_start_index < vision_embeds.shape[0]: + image_embeds = vision_embeds[:video_start_index] + video_embeds = vision_embeds[video_start_index:] + else: + raise ValueError( + f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " + f"{video_start_index}" + ) + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + + if image_embeds is not None or video_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if image_embeds is not None: + image_mask = (input_ids == self.image_token_id).contiguous() + if image_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[image_mask] = image_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + if video_embeds is not None: + video_mask = (input_ids == self.video_token_id).contiguous() + if video_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[video_mask] = video_embeds.to( + dtype=combined_embeddings.dtype, device=combined_embeddings.device + ) + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + else: + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + from .rope_utils import get_rope_index + + position_ids, _ = get_rope_index( + input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask + ) + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + # inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + **(extra_block_kwargs or {}), + ) + + return output diff --git a/verl/models/mcore/qwen2_5_vl/rope_utils.py b/verl/models/mcore/qwen2_5_vl/rope_utils.py new file mode 100644 index 000000000..fadc74daa --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/rope_utils.py @@ -0,0 +1,266 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from megatron.core.models.common.embeddings.rope_utils import * +from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd +from torch import Tensor + +logger = logging.getLogger(__name__) + + +# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index +def get_rope_index( + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + + Examples: + + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + + Examples: + + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each + second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal + tokens" are conceptually packed into a one-second interval of the video. + In this case, we have 25 tokens per second. So each second of the video will be + represented with 25 separate time points. It essentially defines the temporal + granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * + temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be + have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = 2 + tokens_per_second = 2 + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def apply_rotary_pos_emb_thd_absolute( + t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +def apply_rotary_pos_emb_absolute( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + bshd (conventional) / thd (packed seq) format + + In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 + if freqs.shape[1] > 1: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return fused_apply_rotary_pos_emb(t, freqs) + else: + # NOTE: as expected, thd format can use bshd + return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/verl/models/mcore/qwen2_5_vl/vision_config.py b/verl/models/mcore/qwen2_5_vl/vision_config.py new file mode 100644 index 000000000..0631c90f6 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_config.py @@ -0,0 +1,85 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from megatron.core.transformer import TransformerConfig + + +def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: + # Given a Transformer Config from decoder, build vision encoder config + # diff: out_hidden_size & intermediate_size + + # mlp: hidden_size -> intermediate_size -> embed_dim, silu + # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on + if config.num_layers in [28, 36]: + config.ffn_hidden_size = 3420 + else: + config.ffn_hidden_size = 3456 + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth + else: + config.num_layers = 32 # depth + config.num_attention_heads = 16 # num_heads + config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) + config.add_qkv_bias = True # qkv_proj in attn has bias + config.hidden_size = 1280 # hidden_size + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + + # config.gated_linear_unit = False # no gated + # config.activation_func = quick_gelu # hidden_act + config.kv_channels = config.hidden_size // config.num_attention_heads + config.num_query_groups = config.num_attention_heads # no GQA + config.layernorm_zero_centered_gamma = False # False + config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) + config.bias_activation_fusion = False # no swiglu, set false + config.bias_dropout_fusion = False # no dropout, set false + config.attention_softmax_in_fp32 = True # use True + # config.normalization = 'LayerNorm' # use RMSNorm + config.seq_length = 1 + + config.tp_comm_overlap = False + config.sequence_parallel = False + config.temporal_patch_size = 2 + config.patch_size = 14 + config.in_channels = 3 + config.spatial_merge_size = 2 + + config.fullatt_block_indexes = [7, 15, 23, 31] + config._qwen2_5_vl_window_size = 112 + return config + + +def get_vision_projection_config( + config: TransformerConfig, embed_dim: int, spatial_merge_size: int +) -> TransformerConfig: + # merger: + # context_dim = hidden_size * merge_size**2 + # out_hidden_size = hidden_size + # context_dim -> context_dim -> out_hidden_size + # MLP: + # input_size -> ffn_hidden_size -> hidden_size + # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = True + config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) + config.activation_func = torch.nn.functional.gelu + config.tp_comm_overlap = False + config.sequence_parallel = False + return config diff --git a/verl/models/mcore/qwen2_5_vl/vision_model.py b/verl/models/mcore/qwen2_5_vl/vision_model.py new file mode 100644 index 000000000..06b4fd328 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_model.py @@ -0,0 +1,309 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import InferenceParams +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn +from torch.nn import functional as F + +from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs.float() + + +class Qwen2_5VisionModel(VisionModule): + """Qwen2.5 ViT vision model. + + Args: + transformer_config (TransformerConfig): Transformer config. + transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. + ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. + add_class_token (bool, optional): Include a class token. Defaults to True. + class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. + patch_dim (int): Image patch size. + img_h (int): Input image height. + img_w (int): Input image width. + """ + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + projection_config: TransformerConfig, + projection_layer_spec: ModuleSpec, + projection_type: str = "mlp", + pre_process: bool = True, + post_process: bool = False, + ) -> None: + super().__init__(config=transformer_config) + + self.spatial_merge_size = transformer_config.spatial_merge_size + + embed_dim = transformer_config.hidden_size + num_heads = transformer_config.num_attention_heads + temporal_patch_size = transformer_config.temporal_patch_size + patch_size = transformer_config.patch_size + in_channels = transformer_config.in_channels + + self.patch_size = transformer_config.patch_size + self.fullatt_block_indexes = transformer_config.fullatt_block_indexes + self.window_size = transformer_config._qwen2_5_vl_window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.max_sequence_length = transformer_config.seq_length + self.patch_embed = PatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + head_dim = embed_dim // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.model_type = ModelType.encoder_or_decoder + self.pre_process = pre_process + self.post_process = post_process + + # Transformer layers. + # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting + # pipeline parallelism. + # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. + self.decoder = TransformerBlock( + config=transformer_config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=True, + ) + + self.merge_hidden_size = projection_config.ffn_hidden_size + self.square_merge_size = self.merge_hidden_size // embed_dim + + if self.post_process: + self.projection = MultimodalProjector( + projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size + ) + else: + self.projection = None + + self.input_tensor = None + + def set_input_tensor(self, input_tensor: torch.Tensor) -> None: + """Sets input tensor to the model. + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + if self.pre_process: # always True + self.input_tensor = input_tensor + else: + raise NotImplementedError() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + vision_data: Optional[torch.Tensor], + grid_thw: torch.Tensor, + inference_params: Optional[InferenceParams] = None, + extra_block_kwargs: dict = None, + ) -> torch.Tensor: + """Forward function of the Qwen2 Vision Model. This function passes the input tensors + through the embedding layer and then the transformer. + + Args: + x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] + grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame + packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend + + Returns: + x (torch.Tensor): output after final transformer block of shape [b, s, h]. + """ + assert grid_thw is not None + assert self.input_tensor is None + assert inference_params is None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + vision_data = self.patch_embed(vision_data) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=vision_data.device, + dtype=torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = vision_data.size() + vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + vision_data = vision_data[window_index, :, :] + vision_data = vision_data.reshape(seq_len, 1, -1) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) + + hidden_states = self.decoder( + hidden_states=vision_data, + attention_mask=None, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), + packed_seq_params_full=self.build_packed_seq_params(grid_thw), + fullatt_block_indexes=self.fullatt_block_indexes, + **(extra_block_kwargs or {}), + ) + + hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) + reverse_indices = torch.argsort(window_index) + return hidden_states[reverse_indices, :] + + def build_packed_seq_params( + self, + grid_thw: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor] = None, + ) -> PackedSeqParams: + # NOTE: each frame is a sequence (rather than each grid) + if grid_thw is not None: + seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + cu_seqlens = seqlens.cumsum(dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() + else: + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + max_seqlen_q = seqlens.max() + return PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format="thd", + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_q, + ) diff --git a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py new file mode 100644 index 000000000..8f765a0ff --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py @@ -0,0 +1,265 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from megatron.core.transformer.transformer_block import * + + +class Qwen2_5VisionTransformerBlock(TransformerBlock): + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + packed_seq_params_full: PackedSeqParams, + fullatt_block_indexes, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + for index in range(start, end): + if index in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params_now, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == "uniform": + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == "block": + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + packed_seq_params_full: PackedSeqParams = None, + fullatt_block_indexes=None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Update the inference parameters with the current batch size in case it is variable + if inference_context and not self.training: + inference_context.current_batch_size = hidden_states.size(1) + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + packed_seq_params_full=packed_seq_params_full, + fullatt_block_indexes=fullatt_block_indexes, + ) + else: + for l_no, layer in enumerate(self.layers): + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() + ) + if l_no in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + with self.offload_context, inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params_now, + sequence_len_offset=sequence_len_offset, + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + return hidden_states diff --git a/verl/models/mcore/readme.md b/verl/models/mcore/readme.md index bcbeacd58..606dcf189 100644 --- a/verl/models/mcore/readme.md +++ b/verl/models/mcore/readme.md @@ -6,7 +6,7 @@ The migration has been successful with the help of the mcore team and the commun 2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel` 3. support sequence packing/thd format. 4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`. -5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion scipt from huggingface to mcore `dist_checkpointing` format. +5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format. We are working on the following features: - support `Qwen2MoeForCausalLM` @@ -15,7 +15,7 @@ We are working on the following features: - support `expert parallel` Features we invite the community to contribute: -- better scipts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. +- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. - conversion of large models with multiple GPUs - conversion of large models with single GPU - refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format. @@ -33,7 +33,7 @@ main steps: - a. convert the huggingface config to mcore `TransformerConfig` - b. init the mcore `GPTModel` with the converted config - c. load the huggingface model weights to the `GPTModel` -2. online weight conversion from mcore to huggingface (due the the rollout engine `vLLM` is using huggingface format) +2. online weight conversion from mcore to huggingface (due to the rollout engine `vLLM` is using huggingface format) - a. bridge the gap between mcore and huggingface weights format and name mapping - b. online resharding the mcore weights to rollout engine - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine @@ -68,9 +68,9 @@ Most of the features of `GPTModel` is out-of-the-box supported in verl through c Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching. ### checkpointing -The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger.py`. +The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger`. -The existing checkpoint format is simplely save every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. +The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. ## How to support new models @@ -82,7 +82,7 @@ The existing checkpoint format is simplely save every rank's weights and optimiz - d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module. 3. offline weights conversion from huggingface to mcore `dist_checkpointing` format 4. support online weights conversion from mcore to huggingface - - it is recommended to initilize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. + - it is recommended to initialize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. ## How to scale up to larger models like deepseek-v3 or other 100B+ models @@ -96,4 +96,4 @@ The necessary features under development for scaling up are - expert parallel - more efficient and general weight resharding and loading 3. Offline weights conversion - - support weights larger then single GPU memory + - support weights larger than single GPU memory diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index cc27cc2bc..23f01e8b7 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -17,7 +17,7 @@ """ from enum import Enum -from typing import Callable, Dict, Type +from typing import Callable import torch import torch.nn as nn @@ -35,6 +35,11 @@ ) from .model_forward import ( gptmodel_forward, + gptmodel_forward_qwen2_5_vl, +) +from .model_forward_fused import ( + fused_forward_gptmodel, + fused_forward_qwen2_5_vl, ) from .model_initializer import ( BaseModelInitializer, @@ -49,6 +54,7 @@ McoreToHFWeightConverterDense, McoreToHFWeightConverterDpskv3, McoreToHFWeightConverterMixtral, + McoreToHFWeightConverterQwen2_5_VL, McoreToHFWeightConverterQwen2Moe, McoreToHFWeightConverterQwen3Moe, ) @@ -67,7 +73,7 @@ class SupportedModel(Enum): # Registry for model configuration converters -MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { +MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { SupportedModel.LLAMA: hf_to_mcore_config_dense, SupportedModel.QWEN2: hf_to_mcore_config_dense, SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, @@ -77,10 +83,11 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: hf_to_mcore_config_llama4, SupportedModel.QWEN3: hf_to_mcore_config_dense, SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, } # Registry for model initializers -MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = { +MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = { SupportedModel.LLAMA: DenseModel, SupportedModel.QWEN2: DenseModel, SupportedModel.QWEN2_MOE: Qwen2MoEModel, @@ -90,10 +97,11 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: DenseModel, SupportedModel.QWEN3: DenseModel, SupportedModel.QWEN3_MOE: Qwen3MoEModel, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, } # Registry for model forward functions -MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = { +MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { SupportedModel.LLAMA: gptmodel_forward, SupportedModel.QWEN2: gptmodel_forward, SupportedModel.QWEN2_MOE: gptmodel_forward, @@ -103,11 +111,27 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: gptmodel_forward, SupportedModel.QWEN3: gptmodel_forward, SupportedModel.QWEN3_MOE: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, SupportedModel.DEEPSEEK_V3: gptmodel_forward, } +# Registry for model forward functions +MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: fused_forward_gptmodel, + SupportedModel.QWEN2: fused_forward_gptmodel, + SupportedModel.QWEN2_MOE: fused_forward_gptmodel, + SupportedModel.MIXTRAL: fused_forward_gptmodel, + SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, + SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, + SupportedModel.LLAMA4: fused_forward_gptmodel, + SupportedModel.QWEN3: fused_forward_gptmodel, + SupportedModel.QWEN3_MOE: fused_forward_gptmodel, + SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, + SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, +} + # Registry for model weight converters -MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = { +MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = { SupportedModel.LLAMA: McoreToHFWeightConverterDense, SupportedModel.QWEN2: McoreToHFWeightConverterDense, SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, @@ -115,6 +139,7 @@ class SupportedModel(Enum): SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, SupportedModel.QWEN3: McoreToHFWeightConverterDense, SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, + SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, } @@ -123,10 +148,24 @@ def get_supported_model(model_type: str) -> SupportedModel: return SupportedModel(model_type) except ValueError as err: supported_models = [e.value for e in SupportedModel] - raise NotImplementedError(f"Model Type: {model_type} not supported. Supported models: {supported_models}") from err + raise NotImplementedError( + f"Model Type: {model_type} not supported. Supported models: {supported_models}" + ) from err + +def hf_to_mcore_config( + hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs +) -> TransformerConfig: + """Convert huggingface PretrainedConfig to mcore TransformerConfig. + + Args: + hf_config: The huggingface PretrainedConfig. + dtype: The dtype of the model. + **override_transformer_config_kwargs: The kwargs to override the transformer config. -def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: + Returns: + The mcore TransformerConfig. + """ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) @@ -161,7 +200,13 @@ def init_mcore_model( model = get_supported_model(hf_config.architectures[0]) initializer_cls = MODEL_INITIALIZER_REGISTRY[model] initializer = initializer_cls(tfconfig, hf_config) - return initializer.initialize(pre_process=pre_process, post_process=post_process, share_embeddings_and_output_weights=share_embeddings_and_output_weights, value=value, **extra_kwargs) + return initializer.initialize( + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + value=value, + **extra_kwargs, + ) def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: @@ -173,6 +218,15 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: return MODEL_FORWARD_REGISTRY[model] +def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_FORWARD_FUSED_REGISTRY[model] + + def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: """ Get the weight converter for given model architecture. diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py index 153c49ff6..2a954b241 100644 --- a/verl/models/mcore/saver.py +++ b/verl/models/mcore/saver.py @@ -22,10 +22,14 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0): +def _megatron_calc_global_rank( + tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 +): """Calculate global rank with support for CP/EP parallelism""" # Get parallel sizes for each dimension @@ -37,7 +41,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int # Verify total GPU count matches (must be consistent with parallel_state.py) total_size = tp_size * dp_size * pp_size * cp_size - assert total_size == torch.distributed.get_world_size(), f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + assert total_size == torch.distributed.get_world_size(), ( + f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + ) # Core calculation logic (corresponds to RankGenerator order parameter) # Assumes default order is "tp-cp-ep-dp-pp" @@ -62,7 +68,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -104,7 +112,7 @@ def _get_gpt_model(model): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size @@ -115,7 +123,11 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].decoder.layers), num_layers_per_model) + assert len(models[i].decoder.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].decoder.layers), num_layers_per_model + ) + ) state_dict = dict() @@ -156,7 +168,7 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: weight = torch.empty( tensor_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -186,7 +198,7 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -227,7 +239,7 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -277,7 +289,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -337,7 +349,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): state_dict[v_name] = torch.cat(v_weight_list, dim=0) # empty cache before collecting weights - torch.cuda.empty_cache() + get_torch_device().empty_cache() # Embeddings # ------------------- if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks @@ -453,7 +465,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): ) dist.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() if torch.distributed.get_rank() == 0: for k, v in state_dict.items(): if dtype != v.dtype: @@ -463,13 +475,23 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): return state_dict -def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_gptmodel_qwen_moe( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") +def merge_megatron_ckpt_gptmodel_qwen2_5_vl( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") + + def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") -def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_gptmodel_mixtral( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index 59b12374f..c1ef7a211 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -17,11 +17,16 @@ from megatron.core import parallel_state as mpu from megatron.core.packed_seq_params import PackedSeqParams +from verl.utils.model import CausalLMOutputForPPO -def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]: + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: """ Preprocess packed sequences - CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 gets second and second last chunks, and so on), this is for load balancing with causal masking. + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. See https://github.com/NVIDIA/TransformerEngine/issues/1368 """ batch_size = input_ids.shape[0] @@ -54,14 +59,18 @@ def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor start_idx = cu_seqlens_padded[i] // cp_size # split to 2 chunks d = input_ids[i, attention_mask[i]] - input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank remain_end = min(remain_end, d.shape[0]) remain_len = remain_end - remain_start if remain_len > 0: - input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[remain_start:remain_end] + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] packed_seq_params = PackedSeqParams( qkv_format="thd", @@ -107,9 +116,13 @@ def postprocess_packed_seqs( for i in range(batch_size): if cp_size <= 1: s = attention_mask[i].sum().item() - output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s] + output_new[i, attention_mask[i]] = output[0][ + packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s + ] continue - s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]) // cp_size + s_len_padded_chunk = ( + packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i] + ) // cp_size half_seqlen = s_len_padded_chunk // 2 s_len = attention_mask[i].sum().item() s_len_padded = s_len_padded_chunk * cp_size @@ -155,7 +168,9 @@ def remove_left_padding( shape[1] = seq_len if pre_process: new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) - new_attention_mask = torch.zeros(dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len)) + new_attention_mask = torch.zeros( + dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + ) new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) for i in range(batch_size): if pre_process: @@ -188,3 +203,38 @@ def recover_left_padding( for i in range(batch_size): new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] return new_result + + +def postprocess_packed_seqs_for_dict_output( + labels_mask: torch.Tensor, + output: CausalLMOutputForPPO, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> dict[str, torch.Tensor]: + """_summary_ + For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. + This function post-processes each tensor in the output dictionary. + Args: + output (CausalLMOutputForPPO): _description_ + packed_seq_params (PackedSeqParams): _description_ + attention_mask (torch.Tensor): _description_ + batch_size (int): _description_ + seq_len (int): _description_ + post_process (bool, optional): _description_. Defaults to True. + Returns: + CausalLMOutputForPPO: _description_ + """ + ret = {} + output.entropy = output.entropy.view(1, -1) + output.log_probs = output.log_probs.view(1, -1) + output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) + ret["entropy"] = postprocess_packed_seqs( + output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + ret["log_probs"] = postprocess_packed_seqs( + output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + return ret diff --git a/verl/models/mcore/weight_converter.py b/verl/models/mcore/weight_converter.py index cd620f2a8..791513f32 100644 --- a/verl/models/mcore/weight_converter.py +++ b/verl/models/mcore/weight_converter.py @@ -147,6 +147,125 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis return convert_names, params +class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", + "language_model.decoder.final_layernorm.weight": "model.norm.weight", + "language_model.output_layer.weight": "lm_head.weight", + "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight", + "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight", + "vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight", + "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias", + "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight", + "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "self_attention.linear_qkv.bias": [ + "self_attn.q_proj.bias", + "self_attn.k_proj.bias", + "self_attn.v_proj.bias", + ], + "self_attention.linear_qkv.weight": [ + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + ], + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + elif model_type == "vision_model": + name_map_after_layer = { + "self_attention.linear_proj.weight": "attn.proj.weight", + "self_attention.linear_proj.bias": "attn.proj.bias", + "self_attention.linear_qkv.layer_norm_weight": "norm1.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer, None) + if mapped_name is None: + assert "linear_qkv" in name_after_layer + assert len(params) == 3 + new_param = torch.cat(params, dim=0) + params = [new_param] + if "bias" in name_after_layer: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") + else: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + + elif model_type == "vision_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "norm2.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"visual.blocks.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: # mcore @@ -225,7 +344,10 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis "mlp.linear_fc2.weight": "mlp.down_proj.weight", "mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight", "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - "mlp.shared_experts.linear_fc1.weight": ["mlp.shared_experts.gate_proj.weight", "mlp.shared_experts.up_proj.weight"], + "mlp.shared_experts.linear_fc1.weight": [ + "mlp.shared_experts.gate_proj.weight", + "mlp.shared_experts.up_proj.weight", + ], "pre_mlp_layernorm.weight": "post_attention_layernorm.weight", "mlp.router.weight": "mlp.gate.weight", "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", @@ -260,7 +382,12 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now" assert self.mcore_config.num_layers == 61, "only support 61 layers for now" - direct_name_mapping = {"mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", "mtp.layers.0.hnorm.weight": "model.layers.61.hnorm.weight", "mtp.layers.0.eh_proj.weight": "model.layers.61.eh_proj.weight", "mtp.layers.0.final_layernorm.weight": "model.layers.61.shared_head.norm.weight"} + direct_name_mapping = { + "mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", + "mtp.layers.0.hnorm.weight": "model.layers.61.hnorm.weight", + "mtp.layers.0.eh_proj.weight": "model.layers.61.eh_proj.weight", + "mtp.layers.0.final_layernorm.weight": "model.layers.61.shared_head.norm.weight", + } if name in direct_name_mapping: return [direct_name_mapping[name]], [params[0]] assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now" diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py index 7d15a28bb..3168635c7 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -17,6 +17,8 @@ import torch import torch.distributed as dist +from verl.utils.device import get_device_id, get_torch_device + def _megatron_calc_layer_map(config): """Calculate the mapping of global layer_idx to local layer_idx @@ -36,7 +38,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -46,14 +50,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -62,7 +69,9 @@ def _get_gpt_model(model): def fetch_params(module): for param in module.parameters(): - torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + torch.distributed.fetch( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -75,12 +84,15 @@ def fetch_params(module): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -136,12 +148,16 @@ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: if gate_name in state_dict and up_name in state_dict: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: @@ -167,9 +183,11 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] @@ -181,9 +199,11 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head @@ -218,7 +238,9 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( + mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk + ) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -230,7 +252,10 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to layer_name = f"model.layers.{layer}" dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - print(f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}") + print( + f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, " + f"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}" + ) gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) sync_layer = gpt_model_module.model.layers[dst_layer_idx] @@ -308,5 +333,5 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") dist.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py index 8f581176c..770e36533 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -17,6 +17,8 @@ import torch import torch.distributed as dist +from verl.utils.device import get_device_id, get_torch_device + def _megatron_calc_layer_map(config): """Calculate the mapping of global layer_idx to local layer_idx @@ -36,7 +38,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -46,14 +50,17 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): +def load_state_dict_to_megatron_qwen2( + state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False +): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP - from verl.utils.megatron_utils import print_rank_0, unwrap_model + from verl.utils.logger import print_rank_0 + from verl.utils.megatron_utils import unwrap_model start_time = time.time() @@ -62,7 +69,9 @@ def _get_gpt_model(model): def broadcast_params(module): for param in module.parameters(): - torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + torch.distributed.broadcast( + param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -75,12 +84,15 @@ def broadcast_params(module): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " + f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + ) models = [None] * len(wrapped_models) @@ -116,7 +128,7 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor = torch.empty( tensor_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) if torch.distributed.get_rank() == 0: @@ -155,12 +167,14 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -200,12 +214,14 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -224,12 +240,16 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens if torch.distributed.get_rank() == 0: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_gate_up_weight = torch.empty( + config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0) + ) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -248,12 +268,15 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " + f"{tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -282,30 +305,38 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp if not bias: - new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty( + total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -324,12 +355,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sync_tensor = torch.empty( chunk_shape, dtype=params_dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) else: - assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + assert tensor.shape == chunk_shape, ( + f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + ) + sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -438,5 +471,5 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - for wrapped_model in wrapped_models: broadcast_params(wrapped_model) - torch.cuda.empty_cache() + get_torch_device().empty_cache() print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py index 11cba17b1..737f73b4c 100644 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -21,7 +21,9 @@ from megatron.core.transformer.module import Float16Module from torch.nn.parallel import DistributedDataParallel as torchDDP -from verl.utils.megatron_utils import print_rank_0, unwrap_model +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import print_rank_0 +from verl.utils.megatron_utils import unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): @@ -30,7 +32,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( + f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + ) # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -53,7 +57,9 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + layer_offset = ( + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( pp_rank_idx, @@ -94,7 +100,7 @@ def _get_gpt_model(model): assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - if not isinstance(wrapped_models, (list, tuple)): + if not isinstance(wrapped_models, list | tuple): wrapped_models = list(wrapped_models) assert len(wrapped_models) == virtual_pp_size @@ -105,7 +111,11 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model) + assert len(models[i].model.layers) == num_layers_per_model, ( + "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model + ) + ) state_dict = dict() @@ -146,7 +156,7 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: weight = torch.empty( tensor_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -175,7 +185,7 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -215,7 +225,7 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -264,7 +274,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): buffer_tensor = torch.empty( chunk_shape, dtype=dtype, - device=torch.cuda.current_device(), + device=get_device_id(), requires_grad=False, ) @@ -316,7 +326,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): state_dict[v_name] = torch.cat(v_weight_list, dim=0) # empty cache before collecting weights - torch.cuda.empty_cache() + get_torch_device().empty_cache() # Embeddings # ------------------- if dp_rank == 0: @@ -412,7 +422,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): src_pp_rank=pp_size - 1, ) _broadcast_tensor( - gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None, + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None + else None, "reward_head.weight", src_pp_rank=pp_size - 1, ) @@ -426,7 +438,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): dist.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() if torch.distributed.get_rank() == 0: for k, v in state_dict.items(): if dtype != v.dtype: diff --git a/verl/models/qwen2/megatron/layers/__init__.py b/verl/models/qwen2/megatron/layers/__init__.py index b6972ccf7..263ea596f 100644 --- a/verl/models/qwen2/megatron/layers/__init__.py +++ b/verl/models/qwen2/megatron/layers/__init__.py @@ -17,4 +17,10 @@ from .parallel_mlp import ParallelQwen2MLP from .parallel_rmsnorm import ParallelQwen2RMSNorm -__all__ = ["ParallelQwen2Attention", "ParallelQwen2DecoderLayer", "ParallelQwen2DecoderLayerRmPad", "ParallelQwen2MLP", "ParallelQwen2RMSNorm"] +__all__ = [ + "ParallelQwen2Attention", + "ParallelQwen2DecoderLayer", + "ParallelQwen2DecoderLayerRmPad", + "ParallelQwen2MLP", + "ParallelQwen2RMSNorm", +] diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py index d59e5a5f1..702c429c2 100644 --- a/verl/models/qwen2/megatron/layers/parallel_attention.py +++ b/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -19,7 +19,7 @@ # limitations under the License. import math -from typing import Optional, Tuple +from typing import Optional import torch.nn.functional as F from einops import rearrange @@ -50,7 +50,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -103,7 +105,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -160,15 +164,23 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + assert self.num_heads % tp_size == 0, ( + f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + ) + assert self.num_key_value_heads % tp_size == 0, ( + f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" + f"{self.num_key_value_heads}, tp_size={tp_size}" + ) self.num_heads_per_tp = self.num_heads // tp_size self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size self.hidden_size_per_tp = self.hidden_size // tp_size if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads})." + ) column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() @@ -223,7 +235,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states)[0] query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) @@ -242,11 +254,16 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}") + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 @@ -254,7 +271,10 @@ def forward( attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}") + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) @@ -288,8 +308,12 @@ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_l # use flash-attn rotary embeddings with rmpad # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + q_embed = apply_rotary_emb( + q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + k_embed = apply_rotary_emb( + k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) return q_embed, k_embed @@ -309,7 +333,9 @@ def forward( total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) # (total_nnz, 1, hidden_size) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) # (total_nnz, 1, hidden_size) if self.megatron_config.sequence_parallel: sequence_parallel_pad = total_nnz - cu_seqlens[-1] @@ -327,8 +353,11 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + query_states, key_states = apply_rotary_pos_emb_rmpad_flash( + query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + ) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, + # position_ids, indices, # It is recommended to use dropout with FA according to the docs # when training. diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py b/verl/models/qwen2/megatron/layers/parallel_decoder.py index 4217c2897..3c8a2a6ee 100644 --- a/verl/models/qwen2/megatron/layers/parallel_decoder.py +++ b/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from megatron.core import ModelParallelConfig @@ -49,7 +49,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -119,7 +119,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # (total_nnz // sp, 1, hidden_size) hidden_states = self.input_layernorm(hidden_states) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index cc95bfaca..92e81be8d 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -19,7 +19,7 @@ # limitations under the License. """PyTorch Qwen2 model.""" -from typing import Optional, Tuple, Union +from typing import Optional import torch import torch.utils.checkpoint @@ -29,6 +29,7 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast +from verl.utils.device import get_device_name from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config @@ -88,9 +89,13 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): if megatron_config is not None: assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) - self.layers = nn.ModuleList([ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelQwen2RMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask @@ -107,8 +112,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -117,7 +126,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -176,7 +185,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -230,9 +239,13 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): if megatron_config is not None: assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) - self.layers = nn.ModuleList([ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + ) self.norm = ParallelQwen2RMSNorm(config, megatron_config) def forward( @@ -243,7 +256,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -313,7 +326,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -326,7 +339,9 @@ def forward( batch_size, sequence_length = input_ids.shape # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -355,7 +370,9 @@ def forward( logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # add removed padding back - logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -388,7 +405,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: output = super().forward(input_ids, attention_mask, position_ids) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -422,7 +439,9 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pr assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + ) else: self.embed_tokens = None @@ -468,7 +487,7 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> tuple | BaseModelOutputWithPast: """ Args: @@ -523,7 +542,9 @@ def __init__( super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) + self.model = ParallelQwen2ModelRmPadPP( + config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + ) self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process @@ -595,7 +616,7 @@ def setup_embeddings_and_output_layer(self) -> None: if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): weight = self.shared_embedding_or_output_weight() - weight.data = weight.data.cuda() + weight.data = weight.data.to(get_device_name()) torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) def shared_embedding_or_output_weight(self) -> torch.Tensor: @@ -607,7 +628,8 @@ def shared_embedding_or_output_weight(self) -> torch.Tensor: def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head - # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = {self.config.vocab_size}') # [4, 32, 4096] + # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = ' + # f'{self.config.vocab_size}') # [4, 32, 4096] output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() @@ -623,7 +645,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -638,7 +660,9 @@ def forward( # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model batch_size, sequence_length = input_ids.shape # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( + input_ids.unsqueeze(dim=-1), attention_mask + ) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap @@ -666,7 +690,9 @@ def forward( totol_nnz = cu_seqlens[-1] logits = logits[:totol_nnz] # (total_nnz_padded) # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = pad_input( + logits, indices, batch_size, seqlen=sequence_length + ) # (batch_size, sequence_length, vocab_size) return CausalLMOutputWithPast( loss=None, @@ -702,7 +728,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> tuple | CausalLMOutputWithPast: output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) diff --git a/verl/models/registry.py b/verl/models/registry.py index 6fa8effd4..829b9e20c 100644 --- a/verl/models/registry.py +++ b/verl/models/registry.py @@ -13,7 +13,7 @@ # limitations under the License. import importlib -from typing import List, Optional, Type +from typing import Optional import torch.nn as nn @@ -38,7 +38,7 @@ # return model class class ModelRegistry: @staticmethod - def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: + def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]: if model_arch not in _MODELS: return None @@ -54,5 +54,5 @@ def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: return getattr(module, model_cls_name, None) @staticmethod - def get_supported_archs() -> List[str]: + def get_supported_archs() -> list[str]: return list(_MODELS.keys()) diff --git a/verl/models/transformers/dense_common.py b/verl/models/transformers/dense_common.py new file mode 100644 index 000000000..56fe293f5 --- /dev/null +++ b/verl/models/transformers/dense_common.py @@ -0,0 +1,193 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_base_model( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> CausalLMOutputWithPast: + r""" + Copy paste LLaMa's forward + https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py + + This function should be generic enough for all pure text models. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + +def forward_with_torch_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_torch_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def forward_with_triton_backend( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int | torch.Tensor = 0, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | CausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/verl/models/transformers/kimi_vl.py b/verl/models/transformers/kimi_vl.py index b2133a9c5..edd79364b 100644 --- a/verl/models/transformers/kimi_vl.py +++ b/verl/models/transformers/kimi_vl.py @@ -12,58 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward -from verl.utils.ulysses import gather_heads_scatter_seq, gather_outpus_and_unpad, gather_seq_scatter_heads, get_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_rank, get_ulysses_sequence_parallel_world_size, validate_ulysses_config - - -def _merge_with_image_features( - self, - inputs_embeds: torch.Tensor, - input_ids: torch.Tensor, - image_features: torch.Tensor, -): - """ - Args: - inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, input_embed_dim)`): - The input embeddings. - input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`): - The input ids. - image_features (:obj:`torch.Tensor` of shape :obj:`(image_token_nums, image_feature_dim)`): - The image features to merge with the input embeddings. - """ - image_token_index: int = self.config.media_placeholder_token_id - - batch_size, sequence_length, input_embed_dim = inputs_embeds.shape - image_feature_nums, image_feature_dim = image_features.shape - - assert image_feature_dim == input_embed_dim - - image_token_nums = (input_ids == image_token_index).sum() - total_image_token_nums = torch.tensor([image_token_nums], dtype=image_token_nums.dtype, device=input_ids.device) - total_image_token_nums = gather_outpus_and_unpad(total_image_token_nums, gather_dim=0) # [sp_size] - assert image_feature_nums == total_image_token_nums.sum() - - # (batch_size, sequence_length / sp, input_embed_dim) -> (batch_size * sequence_length / sp, input_embed_dim) - inputs_embeds = inputs_embeds.reshape(-1, input_embed_dim) - - # (batch_size, sequence_length / sp) -> (batch_size * sequence_length / sp) - input_ids = input_ids.flatten() - - # split image features and fill in the image token positions if there are image tokens - sp_image_features = image_features.split(total_image_token_nums.tolist(), dim=0) - sp_rank = get_ulysses_sequence_parallel_rank() - image_features = sp_image_features[sp_rank] - inputs_embeds[input_ids == image_token_index] = image_features - - inputs_embeds = inputs_embeds.reshape((batch_size, sequence_length, input_embed_dim)) - - return inputs_embeds +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -132,7 +93,7 @@ def _ulysses_flash_attn_forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: @@ -140,7 +101,6 @@ def _ulysses_flash_attn_forward( else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -148,50 +108,50 @@ def _ulysses_flash_attn_forward( compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - # patch to get all emb + # patch ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - kv_seq_len *= ulysses_sp_size + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads + k_pe = repeat_kv(k_pe, ulysses_sp_size) # to keep heads=1 after a2a + k_nope = repeat_kv(k_nope, num_key_value_groups) + value_states = repeat_kv(value_states, num_key_value_groups) + q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1) + k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1) + k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = q.size(2) # full_q_len = seq_length + + else: + full_q_len = q_len + + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + cos, sin = self.rotary_emb(value_states, seq_len=full_q_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) - # patch - if ulysses_sp_size > 1: - validate_ulysses_config(self.num_heads, ulysses_sp_size) - - num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads - key_states = repeat_kv(key_states, num_key_value_groups) - value_states = repeat_kv(value_states, num_key_value_groups) - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - # (batch_size, num_head / sp_size, seq_length, head_size) - full_q_len = query_states.size(2) # full_q_len = seq_length - - position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] - torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) - position_ids = torch.concat(position_ids_list, dim=-1) - - else: - full_q_len = q_len - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index f44252bb7..687ceab71 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -13,8 +13,7 @@ # limitations under the License. import sys -from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional import torch @@ -25,7 +24,6 @@ from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -48,9 +46,9 @@ def llama_flash_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. @@ -105,7 +103,8 @@ def llama_flash_attn_forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -129,7 +128,11 @@ def llama_flash_attn_forward( else: target_dtype = self.q_proj.weight.dtype - logger.warning_once(f"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.") + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -165,12 +168,12 @@ def llama_flash_attn_forward( def llama_attn_forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. @@ -208,7 +211,11 @@ def llama_attn_forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.') + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -230,84 +237,3 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - - -@dataclass -class CausalLMOutputForPPO(CausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None - - -def forward_for_ppo( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputForPPO]: - r""" - Copy paste LLaMa's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py - - This function should be generic enough for all pure text models. - ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") - - fused_linear_for_ppo = FusedLinearForPPO() - log_probs, entropy = fused_linear_for_ppo.forward( - hidden_states=hidden_states, - vocab_weights=self.lm_head.weight, - input_ids=rolled_labels, - temperature=temperature, - ) - - return CausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index c06b237d9..d6be65a77 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -25,6 +25,7 @@ from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_utils import PreTrainedModel +from verl.utils.import_utils import is_trl_available from verl.utils.ulysses import ( gather_heads_scatter_seq, gather_seq_scatter_heads, @@ -97,7 +98,9 @@ def _ulysses_flash_attention_forward( position_ids = torch.concat(position_ids_list, dim=-1) # (bsz, seq_len, n_head/n, head_dim) - attn_output = _flash_attention_forward(query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs) + attn_output = _flash_attention_forward( + query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs + ) ########## AlltoAll for Ulysses ########## if ulysses_sp_size > 1: @@ -120,7 +123,11 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs): current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - slice_now = inputs_embeds is not None and current_ulysses_sp_size > 1 and getattr(self, "_needs_initial_slice", True) + slice_now = ( + inputs_embeds is not None + and current_ulysses_sp_size > 1 + and getattr(self, "_needs_initial_slice", True) + ) if slice_now: call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) self._needs_initial_slice = False @@ -138,88 +145,160 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs): print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: str = None, +): + """ + Choose the forward function based on the model and backend. + Args: + model (PreTrainedModel): The model to apply the monkey patch. + use_fused_kernels (bool): Whether to use fused kernels. + fused_kernels_backend (str): The backend to use for fused kernels. + """ + if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: + print( + f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is " + f"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}" + ) + return + + forward_with_torch_backend_function = model.__class__.forward + forward_with_triton_backend_function = model.__class__.forward + if model.config.model_type == "qwen2_5_vl": + from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + elif model.config.model_type == "qwen2_vl": + from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + else: + from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + + forward_with_torch_backend_function = forward_with_torch_backend + forward_with_triton_backend_function = forward_with_triton_backend + + if fused_kernels_backend == "triton": + model.__class__.forward = forward_with_triton_backend_function + print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") + elif fused_kernels_backend == "torch": + model.__class__.forward = forward_with_torch_backend_function + print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") + else: + raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + + def apply_monkey_patch( model: PreTrainedModel, ulysses_sp_size: int = 1, use_remove_padding: bool = True, use_fused_kernels: bool = False, + fused_kernels_backend: str = None, ): + """ + Apply monkey patch to the models for ulysses sequence parallel and fused kernel. + + In the end of this function forward function of the model is patched for fused kernel. + If the model is not supported with fused kernel, please return after patch. + """ + """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" module = sys.modules[model.__module__] try: num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads except AttributeError: - num_attention_heads, num_key_value_heads = model.config.text_config.num_attention_heads, model.config.text_config.num_key_value_heads - - assert num_attention_heads % ulysses_sp_size == 0, f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + num_attention_heads, num_key_value_heads = ( + model.config.text_config.num_attention_heads, + model.config.text_config.num_key_value_heads, + ) + + assert num_attention_heads % ulysses_sp_size == 0, ( + f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + ) assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( - f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness." + f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size " + f"{ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," + f"kv heads are repeated to ensure correctness." ) + + if is_trl_available(): + from trl import AutoModelForCausalLMWithValueHead # type: ignore + + def state_dict(self, *args, **kwargs): + return torch.nn.Module.state_dict(self, *args, **kwargs) + + AutoModelForCausalLMWithValueHead.state_dict = state_dict + print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ") + # TODO: VLM models only, unify monkey patch to LLM models. if model.config.model_type == "qwen2_5_vl": - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLFlashAttention2, - Qwen2_5_VLForConditionalGeneration, - ) + if is_transformers_version_in_range(min_version="4.53.0"): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention + + # TODO: Support transformers 4.53 + raise ValueError("Transformers 4.53 is not supported") + else: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention, + ) if use_remove_padding or ulysses_sp_size > 1: from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward + Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2.5VL") if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) else: from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel - patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) - if use_fused_kernels: - from verl.models.transformers.qwen2_5_vl import forward_for_ppo - - Qwen2_5_VLForConditionalGeneration.forward = forward_for_ppo - - return + patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) elif model.config.model_type == "qwen2_vl": - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLFlashAttention2, - Qwen2VLForConditionalGeneration, - ) + if is_transformers_version_in_range(min_version="4.53.0"): + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention + + # TODO: Support transformers 4.53 + raise ValueError("Transformers 4.53 is not supported") + else: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention if use_remove_padding or ulysses_sp_size > 1: from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward + Qwen2VLAttention.forward = ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in Qwen2VL") if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) else: from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel - patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) - - if use_fused_kernels: - from verl.models.transformers.qwen2_vl import forward_for_ppo - - Qwen2VLForConditionalGeneration.forward = forward_for_ppo - return + patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) elif model.config.model_type == "kimi_vl": if use_remove_padding or ulysses_sp_size > 1: # TODO: Changes need to be made when transformers are adapted. - from verl.models.transformers.kimi_vl import _merge_with_image_features, _ulysses_flash_attn_forward + from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward - module.KimiVLForConditionalGeneration._merge_with_image_features = _merge_with_image_features module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward print("Monkey patch FlashAttention2.forward in KimiVL") - + + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM) + if use_fused_kernels: print("Not support fused kernels for KimiVL") @@ -237,10 +316,7 @@ def apply_monkey_patch( flash_attention._flash_attention_forward = _ulysses_flash_attention_forward print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") - if use_fused_kernels: - from verl.models.transformers.llama import forward_for_ppo - - model.__class__.forward = forward_for_ppo + patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) @lru_cache diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py new file mode 100644 index 000000000..e6bb37368 --- /dev/null +++ b/verl/models/transformers/npu_patch.py @@ -0,0 +1,50 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch_npu +from torch_npu import npu_rotary_mul as apply_rotary_emb +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm + + +# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in +# subsequent versions +# https://github.com/huggingface/transformers/pull/38491 +def apply_rotary_pos_emb_flashatt_npu( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + cos = cos.repeat(1, 2) + sin = sin.repeat(1, 2) + q_embed = apply_rotary_emb( + q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(q) + k_embed = apply_rotary_emb( + k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(k) + return q_embed, k_embed + + +# This api can improve performance on ASCEND NPU +def rms_norm_forward(self, x): + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] + + +Qwen2RMSNorm.forward = rms_norm_forward +modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py index e6b81db33..e55fb26d5 100644 --- a/verl/models/transformers/qwen2.py +++ b/verl/models/transformers/qwen2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple +from typing import Callable, Optional import torch from transformers.cache_utils import Cache @@ -39,7 +39,7 @@ def qwen2_flash_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. @@ -103,7 +103,11 @@ def qwen2_flash_attn_forward( else: target_dtype = self.q_proj.weight.dtype - logger.warning_once(f"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.") + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to " + f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " + f"input in {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -114,7 +118,11 @@ def qwen2_flash_attn_forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window else: sliding_window = None @@ -149,12 +157,12 @@ def qwen2_flash_attn_forward( def qwen2_attn_forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. @@ -191,7 +199,11 @@ def qwen2_attn_forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) sliding_window = None - if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward @@ -199,7 +211,11 @@ def qwen2_attn_forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.') + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] diff --git a/verl/models/transformers/qwen2_5_vl.py b/verl/models/transformers/qwen2_5_vl.py index ac4621ec5..51d9753fb 100644 --- a/verl/models/transformers/qwen2_5_vl.py +++ b/verl/models/transformers/qwen2_5_vl.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional import torch from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -28,14 +28,13 @@ class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast): entropy: Optional[torch.FloatTensor] = None -def forward_for_ppo( +def forward_base_model( self: Qwen2_5_VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -47,15 +46,11 @@ def forward_for_ppo( rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]: +) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: r""" Copy paste Qwen2_5_VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -71,7 +66,8 @@ def forward_for_ppo( n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + f"Image features and image tokens do not match: tokens: {n_image_tokens}, " + f"features {n_image_features}" ) mask = input_ids == self.config.image_token_id @@ -89,7 +85,8 @@ def forward_for_ppo( n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + f"Video features and video tokens do not match: tokens: {n_video_tokens}, " + f"features {n_video_features}" ) mask = input_ids == self.config.video_token_id @@ -138,11 +135,57 @@ def forward_for_ppo( return_dict=return_dict, cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) hidden_states = outputs[0] if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") + raise NotImplementedError("forward_with_torch_backend has to return_dict") # Loss calculations if labels is not None: @@ -150,7 +193,7 @@ def forward_for_ppo( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -168,3 +211,78 @@ def forward_for_ppo( attentions=outputs.attentions, rope_deltas=rope_deltas, ) + + +def forward_with_triton_backend( + self: Qwen2_5_VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + second_per_grid_ts=second_per_grid_ts, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return Qwen2_5_VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index a7ae346ec..358b00b6b 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -15,7 +15,7 @@ import inspect import os from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional import torch from transformers.modeling_flash_attention_utils import _flash_attention_forward @@ -140,7 +140,9 @@ def get_rope_index( return position_ids -def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor): +def prepare_fa2_from_position_ids( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor +): query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) value = value.view(-1, value.size(-2), value.size(-1)) @@ -175,7 +177,9 @@ def flash_attention_forward( causal = is_causal if not use_top_left_mask else is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} if is_flash_attn_greater_or_equal("2.4.1"): @@ -185,7 +189,9 @@ def flash_attention_forward( if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): batch_size = query_states.size(0) - query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids[0]) # remove channel dimension + query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids[0] + ) # remove channel dimension cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output = flash_attn_varlen_func( @@ -224,9 +230,9 @@ def ulysses_flash_attn_forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, -) -> Tuple[torch.Tensor, None, None]: +) -> tuple[torch.Tensor, None, None]: from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size @@ -259,7 +265,9 @@ def ulysses_flash_attn_forward( else: cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) dropout_rate = 0.0 if not self.training else self.attention_dropout # Reashape to the expected shape for Flash Attention @@ -267,7 +275,11 @@ def ulysses_flash_attn_forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): sliding_window = self.config.sliding_window else: sliding_window = None @@ -298,14 +310,13 @@ class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): entropy: Optional[torch.FloatTensor] = None -def forward_for_ppo( +def forward_base_model( self: Qwen2VLForConditionalGeneration, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -316,15 +327,11 @@ def forward_for_ppo( video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]: +) -> tuple | Qwen2VLCausalLMOutputWithPast: r""" Copy paste Qwen2VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py ```""" - from verl.utils.experimental.torch_functional import FusedLinearForPPO - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -340,7 +347,8 @@ def forward_for_ppo( n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + f"Image features and image tokens do not match: tokens: {n_image_tokens}, " + f"features {n_image_features}" ) image_mask = ( (input_ids == self.config.image_token_id) @@ -358,7 +366,8 @@ def forward_for_ppo( n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + f"Video features and video tokens do not match: tokens: {n_video_tokens}, " + f"features {n_video_features}" ) video_mask = ( (input_ids == self.config.video_token_id) @@ -401,10 +410,55 @@ def forward_for_ppo( cache_position=cache_position, ) + return outputs + + +def forward_with_torch_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + hidden_states = outputs[0] if not return_dict: - raise NotImplementedError("forward_for_ppo has to return_dict") + raise NotImplementedError("forward_with_torch_backend has to return_dict") # Loss calculations if labels is not None: @@ -412,7 +466,7 @@ def forward_for_ppo( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.") + raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -430,3 +484,76 @@ def forward_for_ppo( attentions=outputs.attentions, rope_deltas=rope_deltas, ) + + +def forward_with_triton_backend( + self: Qwen2VLForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **loss_kwargs, +) -> tuple | Qwen2VLCausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = forward_base_model( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if not return_dict: + raise NotImplementedError("forward_with_triton_backend has to return_dict") + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + + return Qwen2VLCausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py index 8f1e20853..8aa3bc71f 100644 --- a/verl/models/weight_loader_registry.py +++ b/verl/models/weight_loader_registry.py @@ -23,21 +23,34 @@ def get_weight_loader(arch: str): if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] - raise ValueError(f"Model architectures {arch} loader are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}") + raise ValueError( + f"Model architectures {arch} loader are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" + ) def get_weight_saver(arch: str): - from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_dpskv3, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe + from verl.models.mcore.saver import ( + merge_megatron_ckpt_gptmodel, + merge_megatron_ckpt_gptmodel_dpskv3, + merge_megatron_ckpt_gptmodel_mixtral, + merge_megatron_ckpt_gptmodel_qwen2_5_vl, + merge_megatron_ckpt_gptmodel_qwen_moe, + ) _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, } if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] - raise ValueError(f"Model architectures {arch} saver are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}") + raise ValueError( + f"Model architectures {arch} saver are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" + ) diff --git a/verl/protocol.py b/verl/protocol.py index 69e86c731..39979f848 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -22,7 +22,7 @@ import os import pickle from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Optional import numpy as np import pandas as pd @@ -34,7 +34,7 @@ from tensordict import TensorDict from torch.utils.data import DataLoader -from verl.utils.device import get_torch_device +from verl.utils.device import get_device_id, get_torch_device from verl.utils.py_functional import union_two_dict from verl.utils.torch_functional import allgather_dict_tensors @@ -96,6 +96,7 @@ def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int): def unpad_dataproto(data: "DataProto", pad_size): + """Unpad the data proto with pad_size. i.e. `data[:-pad_size]`""" if pad_size != 0: data = data[:-pad_size] return data @@ -103,12 +104,16 @@ def unpad_dataproto(data: "DataProto", pad_size): def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: """Union two tensordicts.""" - assert tensor_dict1.batch_size == tensor_dict2.batch_size, f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) for key in tensor_dict2.keys(): if key not in tensor_dict1.keys(): tensor_dict1[key] = tensor_dict2[key] else: - assert tensor_dict1[key].equal(tensor_dict2[key]), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) return tensor_dict1 @@ -119,7 +124,9 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str assert isinstance(tensor_dict2[key], np.ndarray) assert isinstance(tensor_dict1[key], np.ndarray) # to properly deal with nan and object type - assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) tensor_dict1[key] = val return tensor_dict1 @@ -193,8 +200,8 @@ def collate_fn(x: list["DataProtoItem"]): class DataProtoItem: # TODO(zhangchi.usc1992) add consistency check batch: TensorDict = None - non_tensor_batch: Dict = field(default_factory=dict) - meta_info: Dict = field(default_factory=dict) + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) @dataclass @@ -207,8 +214,8 @@ class DataProto: """ batch: TensorDict = None - non_tensor_batch: Dict = field(default_factory=dict) - meta_info: Dict = field(default_factory=dict) + non_tensor_batch: dict = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) def __post_init__(self): # perform necessary checking @@ -244,11 +251,11 @@ def __getitem__(self, item): return self.slice(item.start, item.stop, item.step) # Case 2: List, numpy array, or torch tensor - use sel_idxs - elif isinstance(item, (list, np.ndarray, torch.Tensor)): + elif isinstance(item, list | np.ndarray | torch.Tensor): return self.select_idxs(item) # Case 3: Single integer - return DataProtoItem for backward compatibility - elif isinstance(item, (int, np.integer)): + elif isinstance(item, int | np.integer): tensor_data = self.batch[item] if self.batch is not None else None non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) @@ -273,7 +280,11 @@ def __setstate__(self, data): batch_deserialized_bytes, non_tensor_batch, meta_info = data batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) - batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None) + batch = torch.load( + batch_deserialized, + weights_only=False, + map_location="cpu" if not get_torch_device().is_available() else None, + ) self.batch = batch self.non_tensor_batch = non_tensor_batch self.meta_info = meta_info @@ -290,11 +301,11 @@ def load_from_disk(filepath) -> "DataProto": def print_size(self, prefix=""): size_of_tensordict = 0 - if self.batch is None: - for key, tensor in self.batch.items(): + if self.batch is not None: + for _, tensor in self.batch.items(): size_of_tensordict += tensor.element_size() * tensor.numel() size_of_numpy_array = 0 - for key, numpy_array in self.non_tensor_batch.items(): + for _, numpy_array in self.non_tensor_batch.items(): size_of_numpy_array += numpy_array.nbytes size_of_numpy_array /= 1024**3 @@ -323,11 +334,16 @@ def check_consistency(self): batch_size = self.batch.batch_size[0] for key, val in self.non_tensor_batch.items(): - assert isinstance(val, np.ndarray), f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}" - assert val.shape[0] == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}" + assert isinstance(val, np.ndarray), ( + f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for " + f"{key=}, got {type(val)=}" + ) + assert val.shape[0] == batch_size, ( + f"key {key} length {len(val)} is not equal to batch size {batch_size}" + ) @classmethod - def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None, auto_padding=False): + def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False): """Create a DataProto from a dict of tensors and non_tensors""" tensors = {} non_tensors = {} @@ -343,7 +359,14 @@ def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding) @classmethod - def from_dict(cls, tensors: Optional[Dict[str, torch.Tensor]] = None, non_tensors=None, meta_info=None, num_batch_dims=1, auto_padding=False): + def from_dict( + cls, + tensors: Optional[dict[str, torch.Tensor]] = None, + non_tensors=None, + meta_info=None, + num_batch_dims=1, + auto_padding=False, + ): """Create a DataProto from a dict of tensors. This assumes that 1. All the tensor in tensors have the same dim0 2. Only dim0 is the batch dim @@ -371,7 +394,10 @@ def from_dict(cls, tensors: Optional[Dict[str, torch.Tensor]] = None, non_tensor pivot_key = key else: current_batch = tensor.shape[:num_batch_dims] - assert batch_size == current_batch, f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}" + assert batch_size == current_batch, ( + f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. " + f"Got {pivot_key} has {batch_size}, {key} has {current_batch}" + ) for key, val in non_tensors.items(): if not isinstance(val, np.ndarray): @@ -457,7 +483,11 @@ def select_idxs(self, idxs): if self.batch is not None: # Use TensorDict's built-in indexing capabilities - selected_batch = TensorDict(source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, batch_size=(batch_size,), device=self.batch.device) + selected_batch = TensorDict( + source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, + batch_size=(batch_size,), + device=self.batch.device, + ) else: selected_batch = None @@ -565,7 +595,9 @@ def validate_input(keys): new_keys = validate_input(new_keys) if len(new_keys) != len(old_keys): - raise ValueError(f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}") + raise ValueError( + f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}" + ) self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) @@ -596,12 +628,15 @@ def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=No Args: - mini_batch_size (int): mini-batch size when iterating the dataset. We require that ``batch.batch_size[0] % mini_batch_size == 0``. + mini_batch_size (int): mini-batch size when iterating the dataset. We require that + ``batch.batch_size[0] % mini_batch_size == 0``. epochs (int): number of epochs when iterating the dataset. - dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The dataloader_kwargs is the kwargs passed to the DataLoader. + dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The + dataloader_kwargs is the kwargs passed to the DataLoader. Returns: - Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is ``self.batch.batch_size * epochs // mini_batch_size`` + Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration + steps is ``self.batch.batch_size * epochs // mini_batch_size`` """ assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" # we can directly create a dataloader from TensorDict @@ -614,8 +649,10 @@ def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=No else: generator = None - assert isinstance(dataloader_kwargs, Dict) - train_dataloader = DataLoader(dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs) + assert isinstance(dataloader_kwargs, dict) + train_dataloader = DataLoader( + dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + ) def get_data(): for _ in range(epochs): @@ -649,7 +686,7 @@ def padding(self, padding_size, padding_candidate=""): self.batch = padded_dp.batch self.non_tensor_batch = padded_dp.non_tensor_batch - def chunk(self, chunks: int) -> List["DataProto"]: + def chunk(self, chunks: int) -> list["DataProto"]: """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. Args: @@ -659,7 +696,9 @@ def chunk(self, chunks: int) -> List["DataProto"]: List[DataProto]: a list of DataProto after splitting """ if not self.is_padding_enabled(): - assert len(self) % chunks == 0, f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." + assert len(self) % chunks == 0, ( + f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." + ) bsz_in_batch = None if self.batch is not None: @@ -682,12 +721,25 @@ def chunk(self, chunks: int) -> List["DataProto"]: output = [] for i in range(chunks): - output.append(type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)) + output.append( + type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + ) return output + def split(self, split_size: int) -> list["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + split_size (int): the size of each split + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + return [self[i : i + split_size] for i in range(0, len(self), split_size)] + @staticmethod - def concat(data: List["DataProto"]) -> "DataProto": + def concat(data: list["DataProto"]) -> "DataProto": """Concat a list of DataProto. The batch is concatenated among dim=0. The meta_info is assumed to be identical and will use the first one. @@ -731,10 +783,15 @@ def repeat(self, repeat_times=2, interleave=True): if self.batch is not None: if interleave: # Interleave the data - repeated_tensors = {key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()} + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } else: # Stack the data - repeated_tensors = {key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) for key, tensor in self.batch.items()} + repeated_tensors = { + key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + for key, tensor in self.batch.items() + } repeated_batch = TensorDict( source=repeated_tensors, @@ -756,7 +813,7 @@ def repeat(self, repeat_times=2, interleave=True): meta_info=self.meta_info, ) - def unfold_column_chunks(self, n_split: int, split_keys: Optional[List[str]] = None): + def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None): """Split along the second dim into `n_split`, unfold it to the first dim (batch dim) Useful in passing grouped tensors that doesn't want to be shuffled in dataset. keys not in split_keys are repeated to match the shape @@ -773,7 +830,9 @@ def unfold_column_chunks(self, n_split: int, split_keys: Optional[List[str]] = N else: unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0) # locate the `unfolded_batch` as a TensorDict on the same device as the original batch - unfolded_batch = TensorDict(source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device) + unfolded_batch = TensorDict( + source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device + ) else: unfolded_batch = None @@ -812,12 +871,16 @@ def sample_level_repeat(self, repeat_times): assert len(repeat_times.shape) == 1 repeat_times = repeat_times.tolist() else: - assert isinstance(repeat_times, list), f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" + assert isinstance(repeat_times, list), ( + f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" + ) repeat_times = torch.tensor(repeat_times) if self.batch is not None: # Interleave the data - repeated_tensors = {key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()} + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } repeated_batch = TensorDict( source=repeated_tensors, @@ -845,7 +908,8 @@ class DataProtoFuture: for data so that asynchronous execution becomes possible. DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. - collect_fn is a Callable that reduces the list of futures to a DataProto - - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select + - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size + and then select Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any @@ -853,15 +917,15 @@ class DataProtoFuture: """ collect_fn: Callable - futures: List[ray.ObjectRef] + futures: list[ray.ObjectRef] dispatch_fn: Callable = None @staticmethod - def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture": + def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture": output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) return output - def chunk(self, chunks: int) -> List["DataProtoFuture"]: + def chunk(self, chunks: int) -> list["DataProtoFuture"]: from functools import partial arg_future_lst = [] @@ -870,7 +934,9 @@ def chunk(self, chunks: int) -> List["DataProtoFuture"]: def dispatch_fn(x, i, chunks): return x.chunk(chunks=chunks)[i] - arg_future = DataProtoFuture(collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures) + arg_future = DataProtoFuture( + collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures + ) arg_future_lst.append(arg_future) return arg_future_lst @@ -889,7 +955,7 @@ def all_gather_data_proto(data: DataProto, process_group): group_size = torch.distributed.get_world_size(group=process_group) assert isinstance(data, DataProto) prev_device = data.batch.device - data.batch = data.batch.to(get_torch_device().current_device()) + data.batch = data.batch.to(get_device_id()) data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch diff --git a/tests/reward/test_codeio_reward.py b/verl/py.typed similarity index 100% rename from tests/reward/test_codeio_reward.py rename to verl/py.typed diff --git a/verl/recipe/dapo/prepare_dapo_data.sh b/verl/recipe/dapo/prepare_dapo_data.sh deleted file mode 100644 index c6c60bf20..000000000 --- a/verl/recipe/dapo/prepare_dapo_data.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -set -uxo pipefail - -export VERL_HOME=${VERL_HOME:-"${HOME}/verl"} -export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"} -export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"} - -mkdir -p "${VERL_HOME}/data" - -wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true" - -wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true" \ No newline at end of file diff --git a/verl/recipe/dapo/src/config/dapo_trainer.yaml b/verl/recipe/dapo/src/config/dapo_trainer.yaml deleted file mode 100644 index 542e125cd..000000000 --- a/verl/recipe/dapo/src/config/dapo_trainer.yaml +++ /dev/null @@ -1,229 +0,0 @@ -data: - tokenizer: null - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - reward_fn_key: data_source - max_prompt_length: 512 - max_response_length: 512 - gen_batch_size: ${data.train_batch_size} - train_batch_size: 1024 - val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - return_full_prompt: False - shuffle: True - filter_overlong_prompts: False - truncation: error - image_key: images - trust_remote_code: True - -actor_rollout_ref: - hybrid_engine: True - model: - path: ~/models/deepseek-llm-7b-chat - external_lib: null - override_config: { } - enable_gradient_checkpointing: True - use_remove_padding: False - actor: - strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: False - ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} - grad_clip: 1.0 - # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) - clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified - clip_ratio_low: 0.2 - clip_ratio_high: 0.2 - clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - # NOTE: "token-mean" is the default behavior - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - use_torch_compile: True # False to disable torch compile - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo - ppo_epochs: 1 - shuffle: False - ulysses_sequence_parallel_size: 1 # sp size - checkpoint: - contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - optim: - lr: 1e-6 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - fsdp_config: - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - param_offload: False - optimizer_offload: False - fsdp_size: -1 - ref: - fsdp_config: - param_offload: False - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size - rollout: - name: vllm - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - use_fire_sampling: False # https://arxiv.org/abs/2410.21236 - prompt_length: ${data.max_prompt_length} # not use for opensource - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_dtensor - tensor_model_parallel_size: 2 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - disable_log_stats: True - enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. - # for hf rollout - do_sample: True - # number of responses (i.e. num sample times) - n: 1 # > 1 for grpo - val_kwargs: - # sampling parameters for validation - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1.0 - temperature: 0 - n: 1 - do_sample: False # default eager for validation - -critic: - rollout_n: ${actor_rollout_ref.rollout.n} - strategy: fsdp - optim: - lr: 1e-5 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - model: - path: ~/models/deepseek-llm-7b-chat - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: { } - external_lib: ${actor_rollout_ref.model.external_lib} - enable_gradient_checkpointing: True - use_remove_padding: False - fsdp_config: - param_offload: False - optimizer_offload: False - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - fsdp_size: -1 - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - forward_micro_batch_size: ${critic.ppo_micro_batch_size} - forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 - forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} - ulysses_sequence_parallel_size: 1 # sp size - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - shuffle: ${actor_rollout_ref.actor.shuffle} - grad_clip: 1.0 - cliprange_value: 0.5 - checkpoint: - contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - -reward_model: - enable: False - strategy: fsdp - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - external_lib: ${actor_rollout_ref.model.external_lib} - use_remove_padding: False - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: False - fsdp_size: -1 - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null # set a number - max_length: null - ulysses_sequence_parallel_size: 1 # sp size - use_dynamic_bsz: ${critic.use_dynamic_bsz} - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - reward_manager: naive - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: True - -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - filter_groups: - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit - -trainer: - balance_batch: True - total_epochs: 30 - total_training_steps: null - project_name: verl_examples - experiment_name: gsm8k - logger: [ 'console', 'wandb' ] - log_val_generations: 0 - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - # auto: find the last ckpt to resume. If can't find, start from scratch - resume_mode: auto # or disable or resume_path if resume_from_path is set - resume_from_path: null - val_before_train: True - test_freq: -1 - critic_warmup: 0 - default_hdfs_dir: null - remove_previous_ckpt_in_save: False - del_local_ckpt_after_load: False - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null - -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/recipe/dapo/src/dapo_ray_trainer.py b/verl/recipe/dapo/src/dapo_ray_trainer.py deleted file mode 100644 index cbd2f4252..000000000 --- a/verl/recipe/dapo/src/dapo_ray_trainer.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -FSDP PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import uuid -from collections import defaultdict -from copy import deepcopy -from pprint import pprint - -import numpy as np -import ray -import torch -from tqdm import tqdm - -from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss -from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_difficulty_histogram_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics -from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask -from verl.trainer.ppo.reward import compute_reward, compute_reward_async - - -class RayDAPOTrainer(RayPPOTrainer): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking(project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, config=OmegaConf.to_container(self.config, resolve=True)) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - timing_raw = defaultdict(float) - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - # (bsz, seq_len) - metrics = {} - - new_batch: DataProto = DataProto.from_single_dict(batch_dict) - num_gen_batches += 1 - # pop those keys for generation - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in new_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - # if "raw_prompt" in new_batch.non_tensor_batch: # Commented out by Reasoning360 as it causes mismatch in multi-domain data - # non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in new_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - gen_batch = new_batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - is_last_step = self.global_steps >= self.total_training_steps - - with _timer("step", timing_raw): - # generate a batch - with _timer("gen", timing_raw): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - else: - self.async_rollout_manager.wake_up() - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - self.async_rollout_manager.sleep() - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - new_batch = new_batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(new_batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - new_batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object) - # repeat to align with repeated responses in rollout - new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - # (bsz*n, seq_len), interleaved, i.e., [A, B] -> [A, A, A, A, B, B, B, B] for n=4 - new_batch = new_batch.union(gen_batch_output) - - new_batch.batch["response_mask"] = compute_response_mask(new_batch) - - with _timer("reward", timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(new_batch) - new_batch = new_batch.union(reward_tensor) - - # we combine with rule-based rm - if self.config.reward_model.launch_reward_fn_async: - future_reward = compute_reward_async.remote(new_batch, self.config, self.tokenizer) - else: - reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn) - - # TODO(yonghao): logics below should be delayed as late as possible - # to maximize overlapping. - if self.config.reward_model.launch_reward_fn_async: - reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - - new_batch.batch["token_level_scores"] = reward_tensor - - print(f"{list(reward_extra_infos_dict.keys())=}") - if reward_extra_infos_dict: - new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) # TODO: This will be cleared if we use multiple genenration batches - else: - new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] - - if not self.config.algorithm.filter_groups.enable: - batch = new_batch - else: # NOTE: When prompts after filtering is less than train batch size, we skip to the next generation batch - metric_name = self.config.algorithm.filter_groups.metric - if metric_name == "seq_final_reward": - # Turn to numpy for easier filtering - new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy() - elif metric_name == "seq_reward": - new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy() - - # Collect the sequence reward for each trajectory - prompt_uid2metric_vals = defaultdict(list) - for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]): - prompt_uid2metric_vals[uid].append(metric_val) - - prompt_uid2metric_std = {} - for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): - prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) - - kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1] - num_prompt_in_batch += len(kept_prompt_uids) - - kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): - if traj_from_prompt_uid in kept_prompt_uids: - kept_traj_idxs.append(idx) - - new_batch = new_batch[kept_traj_idxs] - if batch is None: - batch = new_batch - else: - batch = DataProto.concat([batch, new_batch]) - - prompt_bsz = self.config.data.train_batch_size - if num_prompt_in_batch < prompt_bsz: - print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f"{num_gen_batches=}. Keep generating...") - continue - else: - raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.") - else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] - - # balance the number of valid tokens on each dp rank. - # Note that this breaks the order of data inside the batch. - # Please take care when you implement group based adv computation such as GRPO and rloo - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # recompute old_log_probs - with _timer("old_log_prob", timing_raw): - log_prob_input_batch = batch.select( - batch_keys=["responses", "input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["uid"], - ) - if self.global_steps <= 1: - print(f"removing keys {set(batch.batch.keys()) - set(log_prob_input_batch.batch.keys())} for old logprob.", flush=True) - old_log_prob = self.actor_rollout_wg.compute_log_prob(log_prob_input_batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if "rollout_log_probs" in batch.batch.keys(): - # TODO: we may want to add diff of probs too. - rollout_old_log_probs = batch.batch["rollout_log_probs"] - actor_old_log_probs = batch.batch["old_log_probs"] - attention_mask = batch.batch["attention_mask"] - responses = batch.batch["responses"] - response_length = responses.size(1) - response_mask = attention_mask[:, -response_length:] - - rollout_probs = torch.exp(rollout_old_log_probs) - actor_probs = torch.exp(actor_old_log_probs) - rollout_probs_diff = torch.abs(rollout_probs - actor_probs) - rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) - rollout_probs_diff_max = torch.max(rollout_probs_diff) - rollout_probs_diff_mean = torch.mean(rollout_probs_diff) - rollout_probs_diff_std = torch.std(rollout_probs_diff) - metrics.update( - { - "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), - "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), - "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), - } - ) - - if self.use_reference_policy: - # compute reference log_prob - with _timer("ref", timing_raw): - ref_log_prob_input_batch = batch.select( - batch_keys=["responses", "input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["uid"], - ) - if self.global_steps <= 1: - print(f"removing keys {set(batch.batch.keys()) - set(ref_log_prob_input_batch.batch.keys())} for ref logprob.", flush=True) - if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_input_batch) - else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_input_batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with _timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with _timer("adv", timing_raw): - # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor - - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable, - ) - - # update critic - if self.use_critic: - with _timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - train_batch_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] - if self.global_steps <= 1: - print(f"removing keys {set(batch.batch.keys()) - set(train_batch_keys)} for training.", flush=True) - train_batch = batch.select(batch_keys=train_batch_keys, non_tensor_batch_keys=set()) - with _timer("update_actor", timing_raw): - train_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable - actor_output = self.actor_rollout_wg.update_actor(train_batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # Log rollout generations if enabled - rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) - if rollout_data_dir: - with _timer("dump_rollout_generations", timing_raw): - print(batch.batch.keys()) - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() - self._dump_generations( - inputs=inputs, - outputs=outputs, - scores=scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=rollout_data_dir, - ) - - # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_difficulty_histogram_metrics(batch=batch, config=self.config)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - timing_raw = defaultdict(float) # clear timing - - metrics["train/num_gen_batches"] = num_gen_batches - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/verl/recipe/dapo/src/main_dapo.py b/verl/recipe/dapo/src/main_dapo.py deleted file mode 100644 index b847d4a00..000000000 --- a/verl/recipe/dapo/src/main_dapo.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" -from .dapo_ray_trainer import RayDAPOTrainer - -import os -import ray -import hydra - - -def get_custom_reward_fn(config): - import importlib.util, os - - reward_fn_config = config.get("custom_reward_function") or {} - file_path = reward_fn_config.get("path") - if not file_path: - return None - - if not os.path.exists(file_path): - raise FileNotFoundError(f"Reward function file '{file_path}' not found.") - - spec = importlib.util.spec_from_file_location("custom_module", file_path) - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading module from '{file_path}': {e}") - - function_name = reward_fn_config.get("name") - - if not hasattr(module, function_name): - raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") - - print(f"using customized reward function '{function_name}' from '{file_path}'") - - return getattr(module, function_name) - - -@hydra.main(config_path='config', config_name='dapo_trainer', version_base=None) -def main(config): - run_ppo(config) - - -def run_ppo(config) -> None: - # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices - # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '') - if not ray.is_initialized(): - # this is for local ray cluster - ray.init(runtime_env={ - 'env_vars': { - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN', - 'VLLM_LOGGING_LEVEL': 'WARN' - } - }) - - runner = TaskRunner.remote() - ray.get(runner.run.remote(config)) - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -class TaskRunner: - - def run(self, config): - from verl.utils.fs import copy_to_local - # print initial config - from pprint import pprint - from omegaconf import OmegaConf - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_tokenizer, hf_processor - tokenizer = hf_tokenizer(local_path) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none - - # define worker classes - if config.actor_rollout_ref.actor.strategy == 'fsdp': - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - from verl.single_controller.ray import RayWorkerGroup - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == 'megatron': - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.Critic: ray.remote(CriticWorker), - Role.RefPolicy: ray.remote(ActorRolloutRefWorker) - } - - global_pool_id = 'global_pool' - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - Role.RefPolicy: global_pool_id, - } - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy == 'fsdp': - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == 'megatron': - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - # reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == 'naive': - from verl.workers.reward_manager import NaiveRewardManager - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == 'prime': - from verl.workers.reward_manager import PrimeRewardManager - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == 'dapo': - from verl.workers.reward_manager import DAPORewardManager - reward_manager_cls = DAPORewardManager - elif reward_manager_name == 'async_dapo': - from verl.workers.reward_manager import AsyncDAPORewardManager - reward_manager_cls = AsyncDAPORewardManager - else: - - raise NotImplementedError - - compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=0, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer, - ) - - # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, - num_examine=2, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer, - ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RayDAPOTrainer(config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) - trainer.init_workers() - trainer.fit() - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/verl/self_defined_reward/IF_reward b/verl/self_defined_reward/IF_reward new file mode 160000 index 000000000..210afff7b --- /dev/null +++ b/verl/self_defined_reward/IF_reward @@ -0,0 +1 @@ +Subproject commit 210afff7b2f8065170006e45a9182192c887d522 diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 0dc5a49e5..c22ca34ec 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -15,9 +15,8 @@ import inspect from functools import wraps from types import FunctionType -from typing import Dict, List, Tuple -import torch +import torch # Added by Reasoning360 from verl.protocol import DataProtoFuture, _padding_size_key from verl.utils.py_functional import DynamicEnum @@ -82,12 +81,12 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs): splitted_args = [] for arg in args: - assert isinstance(arg, (DataProto, DataProtoFuture)) + assert isinstance(arg, DataProto | DataProtoFuture) splitted_args.append(arg.chunk(chunks=chunks)) splitted_kwargs = {} for key, val in kwargs.items(): - assert isinstance(val, (DataProto, DataProtoFuture)) + assert isinstance(val, DataProto | DataProtoFuture) splitted_kwargs[key] = val.chunk(chunks=chunks) return splitted_args, splitted_kwargs @@ -156,11 +155,19 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - assert isinstance(worker_group, MegatronWorkerGroup), f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" + assert isinstance(worker_group, MegatronWorkerGroup), ( + f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" + ) + + # ray put all the args in advance to avoid duplicate serialization cost + import ray + + args = [[ray.put(dp_arg) for dp_arg in arg] for arg in args] + kwargs = {k: [ray.put(dp_v) for dp_v in v] for k, v in kwargs.items()} all_args = [] for arg in args: - assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size + assert isinstance(arg, tuple | list) and len(arg) == worker_group.dp_size transformed_args = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank @@ -170,7 +177,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): all_kwargs = {} for k, v in kwargs.items(): - assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size + assert isinstance(v, tuple | list) and len(v) == worker_group.dp_size transformed_v = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank @@ -207,7 +214,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) -def _concat_data_proto_or_future(output: List): +def _concat_data_proto_or_future(output: list): import ray from verl.protocol import DataProto, DataProtoFuture @@ -236,7 +243,7 @@ def collect_megatron_compute_data_proto(worker_group, output): output = collect_megatron_compute(worker_group, output) for o in output: - assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" + assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" return _concat_data_proto_or_future(output) @@ -256,14 +263,15 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): all_args = [] for arg in args: - assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_cp_size + assert isinstance(arg, list | tuple) and len(arg) == pp_dp_cp_size transformed_args = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank # compute the rank in arg. Note that the order is dp then cp then pp - # Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected. + # Also note that the outputs within a pp group will be firstly allgathered, then only the + # output of pp0 will be collected. # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order: # dispatch: pp_allgther: collect: # dp 0 1 2 3 dp 0 1 2 3 @@ -280,7 +288,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): all_kwargs = {} for k, v in kwargs.items(): - assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" + assert isinstance(v, list | tuple) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" transformed_v = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank @@ -349,9 +357,9 @@ def dispatch_dp_compute(worker_group, *args, **kwargs): assert isinstance(worker_group, WorkerGroup) for arg in args: - assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size + assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size for k, v in kwargs.items(): - assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size + assert isinstance(v, tuple | list) and len(v) == worker_group.world_size return args, kwargs @@ -393,12 +401,13 @@ def collect_dp_compute_data_proto(worker_group, output): from verl.protocol import DataProto for o in output: - assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" + assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" output = collect_dp_compute(worker_group, output) return _concat_data_proto_or_future(output) +#### Added by Reasoning360 MAGIC_PREFIX = "__verl_dummy_tensor_" def _materialize_dummy_data_proto(arg): from verl.protocol import DataProto @@ -557,6 +566,7 @@ def dispatch_megatron_pp_dummy_data_proto(worker_group, *args, **kwargs): "dispatch_fn": dummy_direct_rollout_call, "collect_fn": dummy_direct_rollout_call, }, + # Added by Reasoning360 Dispatch.MEGATRON_PP_DUMMY_PROTO: { "dispatch_fn": dispatch_megatron_pp_dummy_data_proto, "collect_fn": collect_megatron_compute_data_proto, @@ -600,8 +610,10 @@ def get_predefined_execute_fn(execute_mode): def _check_dispatch_mode(dispatch_mode): - assert isinstance(dispatch_mode, (Dispatch, Dict)), f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" - if isinstance(dispatch_mode, Dict): + assert isinstance(dispatch_mode, Dispatch | dict), ( + f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" + ) + if isinstance(dispatch_mode, dict): necessary_keys = ["dispatch_fn", "collect_fn"] for key in necessary_keys: assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" @@ -656,10 +668,6 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki Whether the execution should be blocking. Defaults to True. materialize_futures: Whether to materialize the data before dispatching. Defaults to True. - materialize_dummy: - Whether it receives a dummy DataProto. If so, it will materialize a dummy - tensor based on the metadata in the DataProto. This is to receive unused - data for intermediate ranks of pipeline parallel. Returns: A decorator that wraps the original function with distributed execution @@ -668,6 +676,7 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki _check_dispatch_mode(dispatch_mode=dispatch_mode) _check_execute_mode(execute_mode=execute_mode) + # Added by Reasoning360 materialize_dummy = dispatch_mode == Dispatch.MEGATRON_PP_DUMMY_PROTO def decorator(func): @@ -683,6 +692,7 @@ def inner(*args, **kwargs): async def async_inner(*args, **kwargs): if materialize_futures: args, kwargs = _materialize_futures(*args, **kwargs) + # Added by Reasoning360 if materialize_dummy: args, kwargs = _materialize_dummy(*args, **kwargs) return await func(*args, **kwargs) diff --git a/verl/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py index 036163eed..baf6eb839 100644 --- a/verl/single_controller/base/megatron/worker.py +++ b/verl/single_controller/base/megatron/worker.py @@ -47,11 +47,12 @@ def _init_hf_config_and_tf_config( override_model_config, override_transformer_config, trust_remote_code=False, + use_mbridge=False, ): from transformers import AutoConfig from verl.models.mcore import hf_to_mcore_config - from verl.utils import hf_tokenizer + from verl.utils import hf_processor, hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.model import update_model_config @@ -59,10 +60,19 @@ def _init_hf_config_and_tf_config( self.local_path = copy_to_local(model_path) if tokenizer_or_path is None: self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) elif isinstance(tokenizer_or_path, str): self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) else: self.tokenizer = tokenizer_or_path + self.processor = tokenizer_or_path + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template # Step 2: get the hf hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) @@ -86,7 +96,9 @@ def add_optimization_config_to_tf_config(tf_config): if self.config.model.get("enable_gradient_checkpointing", False): gradient_checkpointing_cfg = dict(self.config.model.get("gradient_checkpointing_kwargs", dict())) tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full") - tf_config.recompute_granularity = gradient_checkpointing_cfg.get("activations_checkpoint_granularity", "full") + tf_config.recompute_granularity = gradient_checkpointing_cfg.get( + "activations_checkpoint_granularity", "full" + ) tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1) if megatron_config := self.config.get("megatron", {}): if extra := megatron_config.get("extra", {}): @@ -94,6 +106,15 @@ def add_optimization_config_to_tf_config(tf_config): setattr(tf_config, k, v) add_optimization_config_to_tf_config(tf_config) + if use_mbridge: + from verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(hf_config) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + self.bridge = bridge + else: + self.bridge = None print(f"TF config: {tf_config}") self.hf_config = hf_config diff --git a/verl/single_controller/base/megatron/worker_group.py b/verl/single_controller/base/megatron/worker_group.py index 04d211ffe..b9beb844c 100644 --- a/verl/single_controller/base/megatron/worker_group.py +++ b/verl/single_controller/base/megatron/worker_group.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict from verl.single_controller.base import ResourcePool, WorkerGroup @@ -25,7 +24,7 @@ def __init__(self, resource_pool: ResourcePool, **kwargs): self._megatron_rank_info = None self._megatron_global_info: DistGlobalInfo = None - def init_megatron(self, default_megatron_kwargs: Dict = None): + def init_megatron(self, default_megatron_kwargs: dict = None): raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") def get_megatron_rank_info(self, rank: int) -> DistRankInfo: diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py index 8ff70bd36..ac071cde5 100644 --- a/verl/single_controller/base/register_center/ray.py +++ b/verl/single_controller/base/register_center/ray.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict import ray @@ -22,7 +21,7 @@ class WorkerGroupRegisterCenter: def __init__(self, rank_zero_info): self.rank_zero_info = rank_zero_info # rank -> node_id - self.workers_info: Dict[int, str] = {} + self.workers_info: dict[int, str] = {} def get_rank_zero_info(self): return self.rank_zero_info @@ -30,7 +29,7 @@ def get_rank_zero_info(self): def set_worker_info(self, rank, node_id) -> None: self.workers_info[rank] = node_id - def get_worker_info(self) -> Dict[int, str]: + def get_worker_info(self) -> dict[int, str]: return self.workers_info diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 6ec1668db..2606a3ef3 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -18,10 +18,11 @@ import os import socket from dataclasses import dataclass -from typing import Dict import ray +from verl.utils.device import get_torch_device, get_visible_devices_keyword + from .decorator import Dispatch, Execute, register @@ -42,33 +43,21 @@ class DistGlobalInfo: class WorkerHelper: - def _get_node_ip(self): - def get_node_ip_by_sdk(): - if os.getenv("WG_BACKEND", None) == "ray": - import ray - - return ray._private.services.get_node_ip_address() - else: - raise NotImplementedError("WG_BACKEND now just support ray mode.") - - host_ipv4 = os.getenv("MY_HOST_IP", None) - host_ipv6 = os.getenv("MY_HOST_IPV6", None) - host_ip_by_env = host_ipv4 or host_ipv6 - host_ip_by_sdk = get_node_ip_by_sdk() - - host_ip = host_ip_by_env or host_ip_by_sdk - return host_ip + @staticmethod + def _get_node_ip(): + if os.getenv("WG_BACKEND", None) == "ray": + return ray.util.get_node_ip_address() + else: + raise NotImplementedError("WG_BACKEND now just support ray mode.") - def _get_free_port(self): + @staticmethod + def _get_free_port(): with socket.socket() as sock: sock.bind(("", 0)) return sock.getsockname()[1] def get_availale_master_addr_port(self): - return self._get_node_ip(), str(self._get_free_port()) - - def _get_pid(self): - return os.getpid() + return self._get_node_ip().strip("[]"), str(self._get_free_port()) # we assume that in each WorkerGroup, there is a Master Worker @@ -121,7 +110,9 @@ def _configure_before_init(self, register_center_name: str, rank: int): if os.getenv("WG_BACKEND", None) == "ray": from verl.single_controller.base.register_center.ray import create_worker_group_register_center - self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) + self.register_center = create_worker_group_register_center( + name=register_center_name, info=rank_zero_info + ) os.environ.update(rank_zero_info) else: @@ -133,7 +124,15 @@ def _configure_before_init(self, register_center_name: str, rank: int): @classmethod def env_keys(cls): """The keys of the environment variables that are used to configure the Worker.""" - return ["WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"] + return [ + "WORLD_SIZE", + "RANK", + "LOCAL_WORLD_SIZE", + "LOCAL_RANK", + "MASTER_ADDR", + "MASTER_PORT", + get_visible_devices_keyword().upper(), + ] def __init__(self, cuda_visible_devices=None) -> None: """Initialize the worker with environment settings and device configuration. @@ -142,7 +141,8 @@ def __init__(self, cuda_visible_devices=None) -> None: cuda_visible_devices (str, optional): CUDA visible devices configuration. Defaults to None. """ - # construct a meta from environment variable. Note that the import must be inside the class because it is executed remotely + # construct a meta from environment variable. Note that the import must be inside the class because + # it is executed remotely import os self._setup_env_cuda_visible_devices() @@ -167,7 +167,7 @@ def __init__(self, cuda_visible_devices=None) -> None: "_master_port": master_port, } if cuda_visible_devices is not None: - store["_cuda_visible_devices"] = cuda_visible_devices + store[f"_{get_visible_devices_keyword()}".lower()] = cuda_visible_devices self._configure_with_store(store=store) @@ -183,8 +183,6 @@ def get_fused_worker_by_name(self, worker_name: str): return self.fused_worker_dict.get(worker_name, None) def _setup_env_cuda_visible_devices(self): - import torch - from verl.utils.ray_utils import ray_noset_visible_devices is_ray_noset_visible_devices = ray_noset_visible_devices() @@ -200,10 +198,14 @@ def _setup_env_cuda_visible_devices(self): val = os.environ.pop("HIP_VISIBLE_DEVICES") hip_val = None if cuda_val: - assert val == cuda_val, f"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values found: {val} and {cuda_val}." + assert val == cuda_val, ( + f"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values " + f"found: {val} and {cuda_val}." + ) else: cuda_val = val os.environ["CUDA_VISIBLE_DEVICES"] = val + # os.environ["HIP_VISIBLE_DEVICES"] = val if rocr_val: # You must take care if both HIP/CUDA and ROCR env vars are set as they have @@ -229,9 +231,9 @@ def _setup_env_cuda_visible_devices(self): # so we need to set local rank when the flag is set. local_rank = os.environ.get("RAY_LOCAL_RANK") os.environ["LOCAL_RANK"] = local_rank - torch.cuda.set_device(int(local_rank)) + get_torch_device().set_device(int(local_rank)) - def _configure_with_store(self, store: Dict): + def _configure_with_store(self, store: dict): """ This function should only be called inside by WorkerGroup """ @@ -243,7 +245,9 @@ def _configure_with_store(self, store: Dict): if val is not None: # print(f"set {key} to {val}") os.environ[key] = str(val) - os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + os.environ["REDIS_STORE_SERVER_HOST"] = ( + str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + ) def get_master_addr_port(self): """Get the master address and port for distributed communication.""" @@ -253,8 +257,8 @@ def get_cuda_visible_devices(self): """Get the CUDA visible devices configuration.""" import os - cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") - return cuda_visible_devices + visible_devices = os.environ.get(get_visible_devices_keyword().upper(), "not set") + return visible_devices @property def world_size(self): diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py index 04c4f15be..cb86ab4f5 100644 --- a/verl/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -19,7 +19,7 @@ import signal import threading import time -from typing import Any, Callable, Dict, List +from typing import Any, Callable from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn @@ -60,12 +60,14 @@ def __call__(self) -> Any: def store(self): return self._store - def local_world_size_list(self) -> List[int]: + def local_world_size_list(self) -> list[int]: """Returns a flat list where each process has its local world size.""" - nested_local_world_size_list = [[local_world_size for _ in range(local_world_size)] for local_world_size in self._store] + nested_local_world_size_list = [ + [local_world_size for _ in range(local_world_size)] for local_world_size in self._store + ] return [item for row in nested_local_world_size_list for item in row] - def local_rank_list(self) -> List[int]: + def local_rank_list(self) -> list[int]: """Returns a flat list of local ranks for all processes across all nodes.""" nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] return [item for row in nested_local_rank_list for item in row] @@ -97,7 +99,7 @@ def __call__(self) -> Any: return self.cls(*self.args, **self.kwargs) -def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: +def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None: """Continuously monitors worker processes and raises SIGABRT if any worker dies. Args: @@ -167,7 +169,9 @@ def start_worker_aliveness_check(self, every_n_seconds=1) -> None: # before starting checking worker aliveness, make sure all workers are already alive self._block_until_all_workers_alive() - self._checker_thread = threading.Thread(target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)) + self._checker_thread = threading.Thread( + target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) + ) self._checker_thread.start() @property @@ -197,7 +201,7 @@ def _bind_worker_method(self, user_defined_cls, func_generator): if hasattr(method, MAGIC_ATTR): # this method is decorated by register attribute = getattr(method, MAGIC_ATTR) - assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}" + assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}" assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" dispatch_mode = attribute["dispatch_mode"] diff --git a/verl/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py index 7bcd7bd1e..d2a5d6d3c 100644 --- a/verl/single_controller/ray/__init__.py +++ b/verl/single_controller/ray/__init__.py @@ -12,6 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls, create_colocated_worker_cls_fused +from .base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + create_colocated_worker_cls, + create_colocated_worker_cls_fused, +) -__all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls", "create_colocated_worker_cls_fused"] +__all__ = [ + "RayClassWithInitArgs", + "RayResourcePool", + "RayWorkerGroup", + "create_colocated_worker_cls", + "create_colocated_worker_cls_fused", +] diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 3302ffee8..6c9495d61 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -14,11 +14,9 @@ import inspect import logging -import os import time from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple -from unittest.mock import patch +from typing import Any, Optional import ray from ray.experimental.state.api import get_actor @@ -29,6 +27,7 @@ from verl.protocol import DataProto, _padding_size_key from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch +from verl.utils.py_functional import temp_env_var __all__ = ["Worker"] @@ -62,27 +61,7 @@ def __call__(this, *args, **kwargs): return type(method_name, (Functor,), {})() -def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]: - """ - Sort the placement groups by node ip, all bundles in a single placement group should be on the same node. - - FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK - to be consistent across nodes when resume from checkpoint. - - With this function, if there's only one resource pool and there's no node change, RANK should be consistent - across nodes in multiple ray jobs, even if the whole ray cluster is restarted. - """ - node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()} - pg_ip = {} - for pg in pgs: - specs = ray._private.state.state.placement_group_table(pg.id) - # all bunles should be on the same node - node_id = specs["bundles_to_node_id"][0] - pg_ip[pg.id] = node_ip[node_id] - return sorted(pgs, key=lambda pg: pg_ip[pg.id]) - - -def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]: +def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]: """ Sort the placement groups by node ip, all bundles in a single placement group should be on the same node. @@ -105,7 +84,7 @@ def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[Placement class RayResourcePool(ResourcePool): def __init__( self, - process_on_nodes: Optional[List[int]] = None, + process_on_nodes: Optional[list[int]] = None, use_gpu: bool = True, name_prefix: str = None, max_colocate_count: int = 10, @@ -124,7 +103,9 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c if self.pgs is not None: return self.pgs - pg_name_prefix = name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + pg_name_prefix = ( + name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + ) # print(f"pg_name_prefix = {pg_name_prefix}") if device_name == "npu": device_name = "NPU" @@ -140,7 +121,10 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c lifetime = "detached" if self.detached else None - pgs = [placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) for idx, bundles in enumerate(pg_scheme)] + pgs = [ + placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) + for idx, bundles in enumerate(pg_scheme) + ] ray.get([pg.ready() for pg in pgs]) @@ -148,17 +132,26 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c return pgs -def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool) -> List: - src_pgs = [pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups() if role_name in src_role_names] +def extract_pg_from_exist( + resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool +) -> list: + src_pgs = [ + pg + for role_name, resource_pool in resource_pools.items() + for pg in resource_pool.get_placement_groups() + if role_name in src_role_names + ] sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) - unsorted_pgs: List[Tuple[int, PlacementGroup]] = [] + unsorted_pgs: list[tuple[int, PlacementGroup]] = [] searching_idx = 0 for request_process, original_idx in sorted_process_on_nodes: assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" - assert request_process <= sorted_src_pgs[searching_idx].bundle_count, f"requesting {request_process} processes, bundle count cannot satisfy" + assert request_process <= sorted_src_pgs[searching_idx].bundle_count, ( + f"requesting {request_process} processes, bundle count cannot satisfy" + ) unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) searching_idx += 1 @@ -201,7 +194,7 @@ def set_additional_resource(self, additional_resource): """ self._additional_resource = additional_resource - def update_options(self, options: Dict): + def update_options(self, options: dict): """Update the Ray actor creation options. Args: @@ -209,7 +202,15 @@ def update_options(self, options: Dict): """ self._options.update(options) - def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None, device_name="cuda") -> Any: + def __call__( + self, + placement_group, + placement_group_bundle_idx, + use_gpu: bool = True, + num_gpus=1, + sharing_with=None, + device_name="cuda", + ) -> Any: """Create and return a Ray actor with the configured options. Args: @@ -225,11 +226,15 @@ def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = """ if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) - cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) + visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} - return self.cls.options(**options).remote(*self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs) + return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs) - options = {"scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx)} + options = { + "scheduling_strategy": PlacementGroupSchedulingStrategy( + placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx + ) + } options.update(self._options) if use_gpu and device_name == "cuda": @@ -263,7 +268,7 @@ def __init__( name_prefix: str = None, detached=False, worker_names=None, - worker_handles: List[ray.actor.ActorHandle] = None, + worker_handles: list[ray.actor.ActorHandle] = None, ray_wait_register_center_timeout: int = 300, device_name="cuda", **kwargs, @@ -286,9 +291,14 @@ def __init__( self._ray_wait_register_center_timeout = ray_wait_register_center_timeout # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. self.fused_worker_used = ray_cls_with_init.fused_worker_used - # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to this WorkerGroup. + # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to + # this WorkerGroup. self.sub_cls_name = "" self.device_name = device_name + self.profile_steps = kwargs.get("profile_steps", None) + self.worker_nsight_options = kwargs.get("worker_nsight_options", None) + if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None: + self.worker_nsight_options["capture-range-end"] = f"repeat-shutdown:{6 * len(self.profile_steps)}" if worker_names is not None and (not self.fused_worker_used): assert self._is_init_with_detached_workers @@ -297,7 +307,9 @@ def __init__( if self._is_init_with_detached_workers: self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) else: - self._init_with_resource_pool(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached) + self._init_with_resource_pool( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached + ) if ray_cls_with_init is not None: self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) @@ -373,13 +385,30 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5 - ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) + if self.profile_steps and self.device_name == "cuda": + ray_cls_with_init.update_options( + { + "runtime_env": { + "env_vars": env_vars, + "nsight": self.worker_nsight_options, + }, + "name": name, + } + ) + else: + ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) if detached: ray_cls_with_init.update_options({"lifetime": "detached"}) # create a worker - worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus, device_name=self.device_name) + worker = ray_cls_with_init( + placement_group=pg, + placement_group_bundle_idx=local_rank, + use_gpu=use_gpu, + num_gpus=num_gpus, + device_name=self.device_name, + ) self._workers.append(worker) self._worker_names.append(name) @@ -396,7 +425,8 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d elapsed = int(time.time() - start_time) if elapsed % 30 == 0: logging.warning( - "Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of %s seconds.", + "Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of " + "%s seconds.", actor_name, elapsed, self._ray_wait_register_center_timeout, @@ -430,6 +460,7 @@ def from_detached( worker_names=None, worker_handles=None, ray_cls_with_init=None, + **kwargs, ): """Create a worker group from existing detached workers. @@ -441,7 +472,14 @@ def from_detached( Returns: A new RayWorkerGroup instance """ - worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names, worker_handles=worker_handles) + worker_group = cls( + resource_pool=None, + ray_cls_with_init=ray_cls_with_init, + name_prefix=name_prefix, + worker_names=worker_names, + worker_handles=worker_handles, + **kwargs, + ) return worker_group def spawn(self, prefix_set): @@ -460,7 +498,6 @@ def _rebind_actor_methods(worker_group, actor_name): prefix: str = actor_name + "_" for method_name in dir(worker_group): if method_name.startswith(prefix): - # only valid when Python >= 3.9 original_method_name = method_name.removeprefix(prefix) method = getattr(worker_group, method_name) setattr(worker_group, original_method_name, method) @@ -472,6 +509,8 @@ def _rebind_actor_methods(worker_group, actor_name): worker_names=self._worker_names, worker_handles=self._workers, ray_cls_with_init=self.ray_cls_with_init, + profile_steps=self.profile_steps, + worker_nsight_options=self.worker_nsight_options, ) _rebind_actor_methods(new_worker_group, prefix) @@ -614,7 +653,9 @@ def execute_all_async(self, method_name: str, *args, **kwargs): for i in range(length): sliced_args = tuple(arg[i] for arg in args) sliced_kwargs = {k: v[i] for k, v in kwargs.items()} - result.append(self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs)) + result.append( + self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs) + ) return result return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] @@ -637,7 +678,7 @@ def world_size(self): """ -Utilities that enables creating workers inside the same ray.Actor, +Utilities that enables creating workers inside the same ray.Actor, with code written in separate ray.Actors. """ @@ -679,7 +720,9 @@ async def async_func(self, *args, **kwargs): try: # bind direct rollout method to class without prefix if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: - assert not hasattr(cls, method_name), f"conflict direct rollout method {method_name} with role {key}" + assert not hasattr(cls, method_name), ( + f"conflict direct rollout method {method_name} with role {key}" + ) setattr(cls, method_name, func) print(f"bind role {key} method {method_name} to class {cls}") else: @@ -695,7 +738,7 @@ def _unwrap_ray_remote(cls): return cls -def _determine_fsdp_megatron_base_class(mros: List): +def _determine_fsdp_megatron_base_class(mros: list): """ - megatron: base class should be MegatronWorker - fsdp: base class should be Worker @@ -716,7 +759,9 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ cls_dict = {} init_args_dict = {} - worker_cls = _determine_fsdp_megatron_base_class([cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]) + worker_cls = _determine_fsdp_megatron_base_class( + [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] + ) assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" print(f"colocated worker base class {worker_cls}") @@ -736,8 +781,10 @@ def __init__(self): # directly instantiate the class without remote # in worker class, e.g. # when DISABLE_WORKER_INIT == 1 it will return immediately - with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): - self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})) + with temp_env_var("DISABLE_WORKER_INIT", "1"): + self.worker_dict[key] = user_defined_cls( + *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) + ) # now monkey-patch the methods from inner class to WorkerDict for key, user_defined_cls in cls_dict.items(): @@ -786,8 +833,14 @@ def __init__(self, *args, **kwargs): self.init_args_dict = init_args_dict self.init_kwargs_dict = init_kwargs_dict - for cls_name, udc, ud_args, ud_kwargs in zip(self.cls_names, self.raw_cls_dict.values(), self.init_args_dict.values(), self.init_kwargs_dict.values()): - with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): + for cls_name, udc, ud_args, ud_kwargs in zip( + self.cls_names, + self.raw_cls_dict.values(), + self.init_args_dict.values(), + self.init_kwargs_dict.values(), + strict=True, + ): + with temp_env_var("DISABLE_WORKER_INIT", "1"): udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f"{name_prefixed}_" # cls_name = "actor", "critic", udc = ActorWorker, CriticWorker @@ -805,7 +858,9 @@ def _fuw_execute(self, method_name: str, *args, **kwargs): cls_name = names[0] method_name = names[1] - assert cls_name in self.fused_worker_dict, f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" + assert cls_name in self.fused_worker_dict, ( + f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" + ) udc_method = getattr(self.fused_worker_dict[cls_name], method_name) return udc_method(*args, **kwargs) diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py index 4f56ac1bf..b46fe44a1 100644 --- a/verl/single_controller/ray/megatron.py +++ b/verl/single_controller/ray/megatron.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Optional import ray @@ -40,7 +40,9 @@ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWi """ super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") - self._megatron_global_info: DistGlobalInfo = ray.get(self.execute_rank_zero_async(method_name="get_megatron_global_info")) + self._megatron_global_info: DistGlobalInfo = ray.get( + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): @@ -53,7 +55,7 @@ def __init__( self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, - default_megatron_kwargs: Dict = None, + default_megatron_kwargs: dict = None, **kwargs, ): super().__init__( @@ -64,9 +66,11 @@ def __init__( ) self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") - self._megatron_global_info: DistGlobalInfo = ray.get(self.execute_rank_zero_async(method_name="get_megatron_global_info")) + self._megatron_global_info: DistGlobalInfo = ray.get( + self.execute_rank_zero_async(method_name="get_megatron_global_info") + ) - def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None): + def init_megatron(self, default_megatron_kwargs: Optional[dict] = None): # after super, we will call init of each worker if not self._is_init_with_detached_workers: # only init_megatron if the WorkerGroup is created from scratch diff --git a/verl/test.py b/verl/test.py new file mode 100644 index 000000000..8f7cd89e1 --- /dev/null +++ b/verl/test.py @@ -0,0 +1,6 @@ +a = f"""123 + 456 + 789 + """ + +print(a) \ No newline at end of file diff --git a/verl/third_party/sglang/parallel_state.py b/verl/third_party/sglang/parallel_state.py index 7eea222d1..cdec743d1 100644 --- a/verl/third_party/sglang/parallel_state.py +++ b/verl/third_party/sglang/parallel_state.py @@ -59,7 +59,7 @@ def initialize_parallel_state( assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) if torch.distributed.get_world_size() > 1: - # NOTE: build a sepearate inference group with infer tp & micro dp + # NOTE: build a separate inference group with infer tp & micro dp initialize_model_parallel_for_sglang( tensor_model_parallel_size=tensor_model_parallel_size, num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, @@ -89,9 +89,15 @@ def ensure_model_parallel_initialized( initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) return - assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + f"tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size - assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" + assert pp_world_size == pipeline_model_parallel_size, ( + f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}" + ) # TODO(sgm): deviate from the v0.5.4, not pp now diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index bd2a6d34c..76fe51b3c 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -30,24 +30,26 @@ def get_version(pkg): package_version = get_version(package_name) vllm_version = None - -if package_version == "0.5.4": - vllm_version = "0.5.4" - from .vllm_v_0_5_4 import parallel_state - from .vllm_v_0_5_4.llm import LLM, LLMEngine -elif package_version == "0.6.3" or package_version.startswith("0.6.3"): - # rocm version: "0.6.3+rocmxxx" - vllm_version = "0.6.3" - from .vllm_v_0_6_3 import parallel_state - from .vllm_v_0_6_3.llm import LLM, LLMEngine +if package_version is None: + if not is_sglang_available(): + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) elif vs.parse(package_version) >= vs.parse("0.7.0"): - # From 0.6.6.post2 on, vllm supports SPMD inference - # See https://github.com/vllm-project/vllm/pull/12071 - + vllm_version = package_version from vllm import LLM from vllm.distributed import parallel_state else: + if vs.parse(package_version) in [vs.parse("0.5.4"), vs.parse("0.6.3")]: + raise ValueError( + f"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer " + f"supported. Please use vLLM 0.7.0 or later." + ) if not is_sglang_available(): - raise ValueError(f"vllm version {package_version} not supported and SGLang also not Found. Currently supported vllm versions are 0.6.3 and 0.7.0+") + raise ValueError( + f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " + f"vllm versions are 0.7.0+" + ) -__all__ = ["LLM", "LLMEngine", "parallel_state"] +__all__ = ["LLM", "parallel_state"] diff --git a/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py b/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py deleted file mode 100644 index 28529cc0d..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py +++ /dev/null @@ -1,447 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py - -import argparse -import dataclasses -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union - -from transformers import PretrainedConfig -from vllm.config import ( - CacheConfig, - DecodingConfig, - DeviceConfig, - EngineConfig, - LoRAConfig, - MultiModalConfig, - ObservabilityConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, - TokenizerPoolConfig, -) -from vllm.executor.executor_base import ExecutorBase -from vllm.logger import init_logger - -from .config import LoadConfig, ModelConfig - -if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup - -logger = init_logger(__name__) - - -def nullable_str(val: str): - if not val or val == "None": - return None - return val - - -@dataclass -class EngineArgs: - """Arguments for vLLM engine.""" - - model_hf_config: PretrainedConfig = None # for verl - served_model_name = None # TODO(sgm): check this - # tokenizer: Optional[str] = None # TODO(sgm): check this - skip_tokenizer_init: bool = False - tokenizer_mode: str = "auto" - trust_remote_code: bool = False - download_dir: Optional[str] = None - load_format: str = "auto" - dtype: str = "auto" - kv_cache_dtype: str = "auto" - quantization_param_path: Optional[str] = None - seed: int = 0 - max_model_len: Optional[int] = None - worker_use_ray: bool = False - # Note: Specifying a custom executor backend by passing a class - # is intended for expert use only. The API may change without - # notice. - distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None - pipeline_parallel_size: int = 1 - tensor_parallel_size: int = 1 - max_parallel_loading_workers: Optional[int] = None - block_size: int = 16 - enable_prefix_caching: bool = False - disable_sliding_window: bool = False - use_v2_block_manager: bool = False - swap_space: int = 4 # GiB - cpu_offload_gb: int = 0 # GiB - gpu_memory_utilization: float = 0.90 - max_num_batched_tokens: Optional[int] = None - max_num_seqs: int = 256 - max_logprobs: int = 20 # Default value for OpenAI Chat Completions API - disable_log_stats: bool = False - revision: Optional[str] = None - code_revision: Optional[str] = None - rope_scaling: Optional[dict] = None - rope_theta: Optional[float] = None - tokenizer_revision: Optional[str] = None - quantization: Optional[str] = None - enforce_eager: bool = False - max_context_len_to_capture: Optional[int] = None - max_seq_len_to_capture: int = 8192 - disable_custom_all_reduce: bool = False - tokenizer_pool_size: int = 0 - # Note: Specifying a tokenizer pool by passing a class - # is intended for expert use only. The API may change without - # notice. - tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" - tokenizer_pool_extra_config: Optional[dict] = None - enable_lora: bool = False - max_loras: int = 1 - max_lora_rank: int = 16 - enable_prompt_adapter: bool = False - max_prompt_adapters: int = 1 - max_prompt_adapter_token: int = 0 - fully_sharded_loras: bool = False - lora_extra_vocab_size: int = 256 - long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: str = "auto" - max_cpu_loras: Optional[int] = None - device: str = "auto" - ray_workers_use_nsight: bool = False - num_gpu_blocks_override: Optional[int] = None - num_lookahead_slots: int = 0 - model_loader_extra_config: Optional[dict] = None - ignore_patterns: Optional[Union[str, List[str]]] = None - preemption_mode: Optional[str] = None - - scheduler_delay_factor: float = 0.0 - enable_chunked_prefill: Optional[bool] = None - - guided_decoding_backend: str = "outlines" - # Speculative decoding configuration. - speculative_model: Optional[str] = None - speculative_draft_tensor_parallel_size: Optional[int] = None - num_speculative_tokens: Optional[int] = None - speculative_max_model_len: Optional[int] = None - speculative_disable_by_batch_size: Optional[int] = None - ngram_prompt_lookup_max: Optional[int] = None - ngram_prompt_lookup_min: Optional[int] = None - spec_decoding_acceptance_method: str = "rejection_sampler" - typical_acceptance_sampler_posterior_threshold: Optional[float] = None - typical_acceptance_sampler_posterior_alpha: Optional[float] = None - qlora_adapter_name_or_path: Optional[str] = None - disable_logprobs_during_spec_decoding: Optional[bool] = None - - otlp_traces_endpoint: Optional[str] = None - - @staticmethod - def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Shared CLI arguments for vLLM engine.""" - # Model arguments - # TODO(shengguangming): delete the unused args - parser.add_argument("--model", type=str, default="facebook/opt-125m", help="name or path of the huggingface model to use") - parser.add_argument( - "--tokenizer", - type=str, - default=EngineArgs.tokenizer, - help="name or path of the huggingface tokenizer to use", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - help="the specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.", - ) - parser.add_argument( - "--tokenizer-revision", - type=str, - default=None, - help="the specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.", - ) - parser.add_argument( - "--tokenizer-mode", - type=str, - default=EngineArgs.tokenizer_mode, - choices=["auto", "slow"], - help='tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.', - ) - parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface") - parser.add_argument( - "--download-dir", - type=str, - default=EngineArgs.download_dir, - help="directory to download and load the weights, default to the default cache dir of huggingface", - ) - parser.add_argument( - "--load-format", - type=str, - default=EngineArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy"], - help="The format of the model weights to load. " - '"auto" will try to load the weights in the safetensors format ' - "and fall back to the pytorch bin format if safetensors format " - "is not available. " - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - "a numpy cache to speed up the loading. " - '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling.", - ) - parser.add_argument( - "--dtype", - type=str, - default=EngineArgs.dtype, - choices=["auto", "half", "float16", "bfloat16", "float", "float32"], - help='data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.', - ) - parser.add_argument( - "--max-model-len", - type=int, - default=None, - help="model context length. If unspecified, will be automatically derived from the model.", - ) - # Parallel arguments - parser.add_argument( - "--worker-use-ray", - action="store_true", - help="use Ray for distributed serving, will be automatically set when using more than 1 GPU", - ) - parser.add_argument( - "--pipeline-parallel-size", - "-pp", - type=int, - default=EngineArgs.pipeline_parallel_size, - help="number of pipeline stages", - ) - parser.add_argument( - "--tensor-parallel-size", - "-tp", - type=int, - default=EngineArgs.tensor_parallel_size, - help="number of tensor parallel replicas", - ) - # KV cache arguments - parser.add_argument("--block-size", type=int, default=EngineArgs.block_size, choices=[8, 16, 32], help="token block size") - # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument("--seed", type=int, default=EngineArgs.seed, help="random seed") - parser.add_argument("--swap-space", type=int, default=EngineArgs.swap_space, help="CPU swap space size (GiB) per GPU") - parser.add_argument( - "--gpu-memory-utilization", - type=float, - default=EngineArgs.gpu_memory_utilization, - help="the percentage of GPU memory to be used forthe model executor", - ) - parser.add_argument( - "--max-num-batched-tokens", - type=int, - default=EngineArgs.max_num_batched_tokens, - help="maximum number of batched tokens per iteration", - ) - parser.add_argument( - "--max-num-seqs", - type=int, - default=EngineArgs.max_num_seqs, - help="maximum number of sequences per iteration", - ) - parser.add_argument("--disable-log-stats", action="store_true", help="disable logging statistics") - # Quantization settings. - parser.add_argument( - "--quantization", - "-q", - type=str, - choices=["awq", None], - default=None, - help="Method used to quantize the weights", - ) - return parser - - @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": - # Get the list of attributes of this dataclass. - attrs = [attr.name for attr in dataclasses.fields(cls)] - # Set the attributes from the parsed arguments. - engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) - return engine_args - - def create_engine_config( - self, - ) -> EngineConfig: - # bitsandbytes quantization needs a specific model loader - # so we make sure the quant method and the load format are consistent - if (self.quantization == "bitsandbytes" or self.qlora_adapter_name_or_path is not None) and self.load_format != "bitsandbytes": - raise ValueError(f"BitsAndBytes quantization and QLoRA adapter only support 'bitsandbytes' load format, but got {self.load_format}") - - if (self.load_format == "bitsandbytes" or self.qlora_adapter_name_or_path is not None) and self.quantization != "bitsandbytes": - raise ValueError(f"BitsAndBytes load format and QLoRA adapter only support 'bitsandbytes' quantization, but got {self.quantization}") - - assert self.cpu_offload_gb >= 0, f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}" - - multimodal_config = MultiModalConfig() - device_config = DeviceConfig(self.device) - # NOTE(sgm): we only modify ModelConfig, other configs are import from vllm - model_config = ModelConfig( - hf_config=self.model_hf_config, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - code_revision=self.code_revision, - rope_scaling=self.rope_scaling, - rope_theta=self.rope_theta, - tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, - quantization=self.quantization, - quantization_param_path=self.quantization_param_path, - enforce_eager=self.enforce_eager, - max_context_len_to_capture=self.max_context_len_to_capture, - max_seq_len_to_capture=self.max_seq_len_to_capture, - max_logprobs=self.max_logprobs, - disable_sliding_window=self.disable_sliding_window, - skip_tokenizer_init=self.skip_tokenizer_init, - served_model_name=self.served_model_name, - multimodal_config=multimodal_config, - ) - cache_config = CacheConfig( - block_size=self.block_size, - gpu_memory_utilization=self.gpu_memory_utilization, - swap_space=self.swap_space, - cache_dtype=self.kv_cache_dtype, - num_gpu_blocks_override=self.num_gpu_blocks_override, - sliding_window=model_config.get_sliding_window(), - enable_prefix_caching=self.enable_prefix_caching, - cpu_offload_gb=self.cpu_offload_gb, - ) - parallel_config = ParallelConfig( - pipeline_parallel_size=self.pipeline_parallel_size, - tensor_parallel_size=self.tensor_parallel_size, - worker_use_ray=self.worker_use_ray, - max_parallel_loading_workers=self.max_parallel_loading_workers, - disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), - ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend, - ) - - # NOTE[VERL]: Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - parallel_config.world_size = world_size - - max_model_len = model_config.max_model_len - use_long_context = max_model_len > 32768 - if self.enable_chunked_prefill is None: - # If not explicitly set, enable chunked prefill by default for - # long context (> 32K) models. This is to avoid OOM errors in the - # initial memory profiling phase. - if use_long_context: - is_gpu = device_config.device_type == "cuda" - use_sliding_window = model_config.get_sliding_window() is not None - use_spec_decode = self.speculative_model is not None - has_seqlen_agnostic_layers = model_config.contains_seqlen_agnostic_layers(parallel_config) - if is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and not self.enable_prompt_adapter and not self.enable_prefix_caching and not has_seqlen_agnostic_layers: - self.enable_chunked_prefill = True - logger.warning("Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.") - if self.enable_chunked_prefill is None: - self.enable_chunked_prefill = False - - if not self.enable_chunked_prefill and use_long_context: - logger.warning( - "The model has a long context length (%s). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.", - max_model_len, - ) - - # TODO: spec config - speculative_config = SpeculativeConfig.maybe_create_spec_config( - target_model_config=model_config, - target_parallel_config=parallel_config, - target_dtype=self.dtype, - speculative_model=self.speculative_model, - speculative_draft_tensor_parallel_size=self.speculative_draft_tensor_parallel_size, - num_speculative_tokens=self.num_speculative_tokens, - speculative_disable_by_batch_size=self.speculative_disable_by_batch_size, - speculative_max_model_len=self.speculative_max_model_len, - enable_chunked_prefill=self.enable_chunked_prefill, - use_v2_block_manager=self.use_v2_block_manager, - disable_log_stats=self.disable_log_stats, - ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_acceptance_method=self.spec_decoding_acceptance_method, - typical_acceptance_sampler_posterior_threshold=self.typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=self.typical_acceptance_sampler_posterior_alpha, - disable_logprobs=self.disable_logprobs_during_spec_decoding, - ) - - scheduler_config = SchedulerConfig( - max_num_batched_tokens=self.max_num_batched_tokens, - max_num_seqs=self.max_num_seqs, - max_model_len=model_config.max_model_len, - use_v2_block_manager=self.use_v2_block_manager, - num_lookahead_slots=(self.num_lookahead_slots if speculative_config is None else speculative_config.num_lookahead_slots), - delay_factor=self.scheduler_delay_factor, - enable_chunked_prefill=self.enable_chunked_prefill, - embedding_mode=model_config.embedding_mode, - preemption_mode=self.preemption_mode, - ) - lora_config = ( - LoRAConfig( - max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - long_lora_scaling_factors=self.long_lora_scaling_factors, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None, - ) - if self.enable_lora - else None - ) - - if self.qlora_adapter_name_or_path is not None and self.qlora_adapter_name_or_path != "": - if self.model_loader_extra_config is None: - self.model_loader_extra_config = {} - self.model_loader_extra_config["qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path - - load_config = LoadConfig( - load_format=self.load_format, - download_dir=self.download_dir, - model_loader_extra_config=self.model_loader_extra_config, - ignore_patterns=self.ignore_patterns, - ) - - prompt_adapter_config = PromptAdapterConfig(max_prompt_adapters=self.max_prompt_adapters, max_prompt_adapter_token=self.max_prompt_adapter_token) if self.enable_prompt_adapter else None - - decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend) - - observability_config = ObservabilityConfig(otlp_traces_endpoint=self.otlp_traces_endpoint) - - if model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and not scheduler_config.use_v2_block_manager: - raise ValueError("Chunked prefill is not supported with sliding window. Set --disable-sliding-window to disable sliding window.") - - return EngineConfig( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - speculative_config=speculative_config, - load_config=load_config, - decoding_config=decoding_config, - observability_config=observability_config, - prompt_adapter_config=prompt_adapter_config, - ) diff --git a/verl/third_party/vllm/vllm_v_0_5_4/config.py b/verl/third_party/vllm/vllm_v_0_5_4/config.py deleted file mode 100644 index 2133c6898..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/config.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py - -import enum -import json -from dataclasses import dataclass, field -from typing import List, Optional, Union - -import torch -from transformers import PretrainedConfig - -# Add for verl -from vllm.config import ( - ModelConfig, - MultiModalConfig, - _get_and_verify_dtype, - _get_and_verify_max_len, - get_served_model_name, -) -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import get_quantization_config -from vllm.model_executor.model_loader import BaseModelLoader -from vllm.transformers_utils.config import get_hf_text_config -from vllm.utils import is_hip, print_warning_once - -GPTQMarlinConfig = get_quantization_config("gptq_marlin") - -logger = init_logger(__name__) - -_GB = 1 << 30 - - -class ModelConfig(ModelConfig): - """Configuration for the model. - - Args: - model: Name or path of the huggingface model to use. - tokenizer: Name or path of the huggingface tokenizer to use. - tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if - available, and "slow" will always use the slow tokenizer. - trust_remote_code: Trust remote code (e.g., from HuggingFace) when - downloading the model and tokenizer. - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - dtype: Data type for model weights and activations. The "auto" option - will use FP16 precision for FP32 and FP16 models, and BF16 precision - for BF16 models. - seed: Random seed for reproducibility. - revision: The specific model version to use. It can be a branch name, - a tag name, or a commit id. If unspecified, will use the default - version. - code_revision: The specific revision to use for the model code on - Hugging Face Hub. It can be a branch name, a tag name, or a - commit id. If unspecified, will use the default version. - tokenizer_revision: The specific tokenizer version to use. It can be a - branch name, a tag name, or a commit id. If unspecified, will use - the default version. - max_model_len: Maximum length of a sequence (including prompt and - output). If None, will be derived from the model. - quantization: Quantization method that was used to quantize the model - weights. If None, we assume the model weights are not quantized. - quantization_param_path: Path to JSON file containing scaling factors. - Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the - model dtype is FP8_E4M3 on ROCm. - enforce_eager: Whether to enforce eager execution. If True, we will - disable CUDA graph and always execute the model in eager mode. - If False, we will use CUDA graph and eager execution in hybrid. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). - max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode - skip_tokenizer_init: If true, skip initialization of tokenizer and - detokenizer. - served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, - the model name will be the same as `model`. - """ - - def __init__( - self, - hf_config: PretrainedConfig, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: bool = False, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - multimodal_config: Optional[MultiModalConfig] = None, - ) -> None: - self.model = hf_config._name_or_path - self.tokenizer = hf_config._name_or_path - # NOTE(sgm): same as open-sourced - self.tokenizer_mode = tokenizer_mode - self.trust_remote_code = trust_remote_code - self.seed = seed - self.revision = revision - self.code_revision = code_revision - self.rope_scaling = rope_scaling - self.rope_theta = rope_theta - # The tokenizer version is consistent with the model version by default. - if tokenizer_revision is None: - self.tokenizer_revision = revision - else: - self.tokenizer_revision = tokenizer_revision - self.quantization = quantization - self.quantization_param_path = quantization_param_path - self.enforce_eager = enforce_eager - if max_context_len_to_capture is not None: - raise ValueError("`max_context_len_to_capture` is deprecated. Use `max_seq_len_to_capture` instead.") - self.max_seq_len_to_capture = max_seq_len_to_capture - self.max_logprobs = max_logprobs - self.disable_sliding_window = disable_sliding_window - self.skip_tokenizer_init = skip_tokenizer_init - - # self.hf_config = get_config(model, trust_remote_code, revision) - self.hf_config = hf_config - self.hf_text_config = get_hf_text_config(hf_config) - self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - # self.served_model_name = get_served_model_name(model, - # served_model_name) - # self._verify_load_format() - # self._verify_tokenizer_mode() - if not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and self.hf_text_config.sliding_window is not None: - print_warning_once(f"Gemma 2 uses sliding window attention for every odd layer, which is currently not supported by vLLM. Disabling sliding window and capping the max length to the sliding window size ({self.hf_text_config.sliding_window}).") - self.disable_sliding_window = True - - self.max_model_len = _get_and_verify_max_len( - hf_config=self.hf_text_config, - max_model_len=max_model_len, - disable_sliding_window=self.disable_sliding_window, - sliding_window_len=self.get_hf_config_sliding_window(), - ) - self.served_model_name = get_served_model_name( - self.model, # str - served_model_name, - ) - self.multimodal_config = multimodal_config - - if not self.skip_tokenizer_init: - self._verify_tokenizer_mode() - self._verify_embedding_mode() - self._verify_quantization() - self._verify_cuda_graph() - - -class LoadFormat(str, enum.Enum): - AUTO = "auto" - MEGATRON = "megatron" - HF = "hf" - DTENSOR = "dtensor" - DUMMY_HF = "dummy_hf" - DUMMY_MEGATRON = "dummy_megatron" - DUMMY_DTENSOR = "dummy_dtensor" - - -# TODO: check whether this is necessary -@dataclass -class LoadConfig: - """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. - "bitsandbytes" will load nf4 type weights. - ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's - checkpoints. - - """ - - load_format: Union[str, LoadFormat, BaseModelLoader] = LoadFormat.AUTO - download_dir: Optional[str] = None - model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) - ignore_patterns: Optional[Union[List[str], str]] = None - - def __post_init__(self): - model_loader_extra_config = self.model_loader_extra_config or {} - if isinstance(model_loader_extra_config, str): - self.model_loader_extra_config = json.loads(model_loader_extra_config) - self._verify_load_format() - - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - def _verify_load_format(self) -> None: - if not isinstance(self.load_format, str): - return - - load_format = self.load_format.lower() - self.load_format = LoadFormat(load_format) - - rocm_not_supported_load_format: List[str] = [] - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)] - raise ValueError(f"load format '{load_format}' is not supported in ROCm. Supported load formats are {rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py deleted file mode 100644 index 22cca0950..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models - -from typing import Dict - -import torch.nn as nn -from torch.distributed._tensor import DTensor -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import * -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import is_pp_missing_parameter - - -def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - stacked_name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if stacked_name.endswith(".bias") and stacked_name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[stacked_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # lm_head is not used in vllm as it is tied with embed_token. - # To prevent errors, skip loading lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "lm_head.weight" in name: - continue - if ".attn.bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight) - - -def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=vllm_model.config.n_routed_experts, - ) - - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if ("mlp.experts." in name) and name not in params_dict: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - local_loaded_weight.to(dtype=param.dtype), - weight_name, - shard_id=shard_id, - expert_id=expert_id, - ) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - pass - - -def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): - param_name = _process_parameter_names(name=param_name) - if parallelize_plan is not None: - assert param_name in parallelize_plan, f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" - placement = parallelize_plan[param_name] - local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, placements=placement).to_local() - else: - local_loaded_weights = loaded_weights.full_tensor() - return local_loaded_weights - - -def _process_parameter_names(name): - # Remove '.weight' if it exists at the end of the string - if name.endswith(".weight"): - name = name[:-7] - - # Remove 'model.layers.x.' or 'model.' prefix - if "model.layers" in name: - parts = name.split(".") - # Reconstruct the string without 'model.layers.x.' - name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' - elif name.startswith("model."): - name = name[6:] # Remove 'model.' - - return name - - -__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { - "GPT2LMHeadModel": gpt2_dtensor_weight_loader, - "LlamaForCausalLM": llama_dtensor_weight_loader, - "LLaMAForCausalLM": llama_dtensor_weight_loader, - "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM - "InternLMForCausalLM": llama_dtensor_weight_loader, - "AquilaModel": llama_dtensor_weight_loader, - "AquilaForCausalLM": llama_dtensor_weight_loader, - "Phi3ForCausalLM": llama_dtensor_weight_loader, - "GemmaForCausalLM": gemma_dtensor_weight_loader, - "Gemma2ForCausalLM": gemma_dtensor_weight_loader, - "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, - "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, - "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, - "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, -} - - -# the actor model is .state_dict() -# Load dtensor weights -def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): - weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) - weight_loader(actor_weights, vllm_model) - # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu - # after init, and we need this after sync model weights for in first iter. - vllm_model = vllm_model.cuda() - - -def _get_model_weight_loader(arch: str): - if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: - return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") - - -# NOTE(sgm): we use per-parameter weight loader in each vllm sub -def update_dtensor_weight_loader(): - pass diff --git a/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py deleted file mode 100644 index 0de56a008..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models - -from typing import Dict - -import torch.nn as nn -from vllm.model_executor.model_loader.utils import set_default_torch_dtype - - -def update_hf_weight_loader(): - print("no hf weight loader need to be updated") - return - - -def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): - assert isinstance(actor_weights, Dict) - with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: - del actor_weights["lm_head.weight"] - vllm_model.load_weights(actor_weights.items()) - for _, module in vllm_model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm.py b/verl/third_party/vllm/vllm_v_0_5_4/llm.py deleted file mode 100644 index a1affd701..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/llm.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py - -from typing import Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm -from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast -from vllm import LLM -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.utils import Counter - -from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer - -from .arg_utils import EngineArgs -from .llm_engine_sp import LLMEngine - - -class LLM(LLM): - """An LLM for generating texts from given prompts and sampling parameters. - - This class includes a tokenizer, a language model (possibly distributed - across multiple GPUs), and GPU memory space allocated for intermediate - states (aka KV cache). Given a batch of prompts and sampling parameters, - this class generates texts from the model, using an intelligent batching - mechanism and efficient memory management. - - NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. - - Args: - model: A HuggingFace Transformers model instance. - tokenizer: A HuggingFace Transformers tokenizer instance. - tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer - if available, and "slow" will always use the slow tokenizer. - trust_remote_code: Trust remote code (e.g., from HuggingFace) when - downloading the model and tokenizer. - tensor_parallel_size: The number of GPUs to use for distributed - execution with tensor parallelism. - dtype: The data type for the model weights and activations. Currently, - we support `float32`, `float16`, and `bfloat16`. If `auto`, we use - the `torch_dtype` attribute specified in the model config file. - However, if the `torch_dtype` in the config is `float32`, we will - use `float16` instead. - quantization: The method used to quantize the model weights. Currently, - we support "awq". If None, we assume the model weights are not - quantized and use `dtype` to determine the data type of the weights. - revision: The specific model version to use. It can be a branch name, - a tag name, or a commit id. - tokenizer_revision: The specific tokenizer version to use. It can be a - branch name, a tag name, or a commit id. - seed: The seed to initialize the random number generator for sampling. - gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to - reserve for the model weights, activations, and KV cache. Higher - values will increase the KV cache size and thus improve the model's - throughput. However, if the value is too high, it may cause out-of- - memory (OOM) errors. - swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - This can be used for temporarily storing the states of the requests - when their `best_of` sampling parameters are larger than 1. If all - requests will have `best_of=1`, you can safely set this to 0. - Otherwise, too small values may cause out-of-memory (OOM) errors. - enforce_eager: Whether to enforce eager execution. If True, we will - disable CUDA graph and always execute the model in eager mode. - If False, we will use CUDA graph and eager execution in hybrid. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. - disable_custom_all_reduce: See ParallelConfig - """ - - def __init__( - self, - model: Union[nn.Module, Dict], # model itself or its parameter dict - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], - model_hf_config: PretrainedConfig, - tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - skip_tokenizer_init: bool = False, - tensor_parallel_size: int = 1, - dtype: str = "auto", - quantization: Optional[str] = None, - revision: Optional[str] = None, - tokenizer_revision: Optional[str] = None, - seed: int = 0, - gpu_memory_utilization: float = 0.9, - swap_space: int = 4, - cpu_offload_gb: float = 0, - enforce_eager: bool = False, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = False, - load_format="auto", - **kwargs, - ) -> None: - if "disable_log_stats" not in kwargs: - kwargs["disable_log_stats"] = True - engine_args = EngineArgs( - model_hf_config=model_hf_config, - tensor_parallel_size=tensor_parallel_size, - dtype=dtype, - quantization=quantization, - revision=revision, - tokenizer_revision=tokenizer_revision, - seed=seed, - gpu_memory_utilization=gpu_memory_utilization, - swap_space=swap_space, - cpu_offload_gb=cpu_offload_gb, - enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, - max_seq_len_to_capture=max_seq_len_to_capture, - disable_custom_all_reduce=disable_custom_all_reduce, - load_format=load_format, - skip_tokenizer_init=skip_tokenizer_init, - **kwargs, - ) - tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) - if not isinstance(tokenizer, tokenizer_cls): - raise ValueError(f"Unexpected tokenizer type: {type(tokenizer)}. Must beone of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer") - self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext - self.request_counter = Counter() - - def init_cache_engine(self): - self.llm_engine.init_cache_engine() - - def free_cache_engine(self): - self.llm_engine.free_cache_engine() - - def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_engine.tokenizer - - def set_tokenizer( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - ) -> None: - self.llm_engine.tokenizer = tokenizer - - def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: - # Initialize tqdm. - if use_tqdm: - num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm( - total=num_requests, - desc="Processed prompts", - dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), - ) - # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] - total_in_toks = 0 - total_out_toks = 0 - while self.llm_engine.has_unfinished_requests(): - step_outputs = self.llm_engine.step() - for output in step_outputs: - if output.finished: - outputs.append(output) - if use_tqdm: - if isinstance(output, RequestOutput): - # Calculate tokens only for RequestOutput - total_in_toks += len(output.prompt_token_ids) - in_spd = total_in_toks / pbar.format_dict["elapsed"] - total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) - out_spd = total_out_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" - pbar.update(1) - if use_tqdm: - pbar.close() - # Sort the outputs by request ID. - # This is necessary because some requests may be finished earlier than - # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return self._post_process_outputs(outputs) - - # # NOTE(shengguangming): add for verl - # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. - # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: - # # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - # token_ids = prompt_token_ids[non_pad_index:].tolist() - # return token_ids - - # NOTE(shengguangming): add for verl - def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: - output_token_ids = [] - logprobs = [] - for request_output in request_outputs: # List[RequestOutput] - outputs = request_output.outputs - for output in outputs: # List[CompletionOutput], usually len == 1 - output_token_ids.append(torch.tensor(output.token_ids)) - # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits - logprobs_dicts = output.logprobs - if logprobs_dicts is not None: - logprob = [] - for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): - logprob.append(logprobs_dict[id].logprob) - logprobs.append(torch.tensor(logprob)) - - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) - if len(logprobs) > 0: - logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) - return output_token_ids, logprobs - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - def offload_model_weights(self) -> None: - self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py deleted file mode 100644 index eface0777..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py - -from typing import Dict, Iterable, Optional, Type, Union - -from torch import nn -from vllm.config import ( - CacheConfig, - DecodingConfig, - DeviceConfig, - EngineConfig, - LoRAConfig, - MultiModalConfig, - ObservabilityConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) -from vllm.core.scheduler import Scheduler -from vllm.engine.llm_engine import LLMEngine, _load_generation_config_dict -from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger, StatLoggerBase -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import INPUT_REGISTRY -from vllm.logger import init_logger -from vllm.tracing import init_tracer -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message -from vllm.utils import Counter -from vllm.version import __version__ as VLLM_VERSION - -from .arg_utils import EngineArgs -from .config import LoadConfig, ModelConfig -from .tokenizer import TokenizerGroup - -logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - - -class LLMEngine(LLMEngine): - """An LLM engine that receives requests and generates texts. - - This is the main class for the vLLM engine. It receives requests - from clients and generates texts from the LLM. It includes a tokenizer, a - language model (possibly distributed across multiple GPUs), and GPU memory - space allocated for intermediate states (aka KV cache). This class utilizes - iteration-level scheduling and efficient memory management to maximize the - serving throughput. - - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. - - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. - - Args: - model: the actor model initialize outside vllm (add for verl) - tokenizer: the initialized tokenizer (add for verl) - model_config: The configuration related to the LLM model. - cache_config: The configuration related to the KV cache memory - management. - parallel_config: The configuration related to distributed execution. - scheduler_config: The configuration related to the request scheduler. - distributed_init_method: The initialization method for distributed - execution. See `torch.distributed.init_process_group` for details. - placement_group: Ray placement group for distributed execution. - Required for distributed execution. - log_stats: Whether to log statistics. - """ - - def __init__( - self, - # NOTE(sgm): first two arguments are added for verl - model: Union[nn.Module, Dict], # model itself or its parameter dict - tokenizer: nn.Module, - # NOTE(sgm): vllm original arguments - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - executor_class: Type[ExecutorBase], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> None: - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, revision=%s, " - "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.revision, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.use_v2_block_manager, - cache_config.enable_prefix_caching, - ) - # TODO(woosuk): Print more configs in debug mode. - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.multimodal_config = multimodal_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig() - self.log_stats = log_stats - - # self.model = model # should not store the model, it should be deleted - # TODO(shengguangming): maybe we can choose init here or from arguments - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer(tokenizer) - self.detokenizer = Detokenizer(self.tokenizer) - else: - self.tokenizer = None - self.detokenizer = None - - self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict(model_config) - - self.input_processor = INPUT_REGISTRY.create_input_processor(self.model_config) - - self.model_executor = executor_class( - model=model, # add for spmd_gpu_executor - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - ) - - # Profile the memory usage and initialize the cache. - if not self.model_config.embedding_mode: - self._initialize_kv_caches() - - # If usage stat is enabled, collect relevant info. - if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import get_architecture_class_name - - usage_message.report_usage( - get_architecture_class_name(model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": str(model_config.dtype), - "tensor_parallel_size": parallel_config.tensor_parallel_size, - "block_size": cache_config.block_size, - "gpu_memory_utilization": cache_config.gpu_memory_utilization, - # Quantization - "quantization": model_config.quantization, - "kv_cache_dtype": str(cache_config.cache_dtype), - # Feature flags - "enable_lora": bool(lora_config), - "enable_prompt_adapter": bool(prompt_adapter_config), - "enable_prefix_caching": cache_config.enable_prefix_caching, - "enforce_eager": model_config.enforce_eager, - "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, - }, - ) - - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = [Scheduler(scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size) for _ in range(parallel_config.pipeline_parallel_size)] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - self.stat_loggers = { - "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len, - ), - } - self.stat_loggers["prometheus"].info("cache_config", self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. - self.output_processor = SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - self.get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - self.get_tokenizer_for_seq, - ), - ) - - # TODO(sgm): add for verl but we may not tokenizer in Rollout - def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None) - init_kwargs.update(tokenizer_init_kwargs) - return TokenizerGroup(tokenizer, **init_kwargs) - - def init_cache_engine(self): - # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache - # Re-capture CUDAGraph would be time-consuming - self.model_executor.init_cache_engine() - - def free_cache_engine(self): - self.model_executor.free_cache_engine() - - # NOTE(sgm): currently, we only support GPU executor - # The GPUExecutor remove the Ray dependency - @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: - assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" - - if engine_config.parallel_config.world_size == 1: - engine_config.load_config.load_format = "dummy_hf" - - from .spmd_gpu_executor import SPMDGPUExecutor - - executor_class = SPMDGPUExecutor - return executor_class - - @classmethod - def from_engine_args( - cls, - model, - tokenizer, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Initialize the cluster and specify the executor class. - assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" - - from .spmd_gpu_executor import SPMDGPUExecutor - - executor_class = SPMDGPUExecutor - - # Create the LLM engine. - engine = cls( - model, - tokenizer, - **engine_config.to_dict(), - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - def offload_model_weights(self) -> None: - self.model_executor.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py deleted file mode 100644 index cdcac7840..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models - -from typing import Dict, Iterable - -import torch -import torch.nn as nn -from vllm.model_executor.layers.linear import * -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding -from vllm.model_executor.models import ModelRegistry - - -# NOTE(shengguangming): replace the origin weight loader function in the class -def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Parallel Linear weight loader.""" - assert param.size() == loaded_weight.size(), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format(param.size(), loaded_weight.size()) - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" - - param.data = loaded_weight.data - - -def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Default weight loader.""" - assert param.size() == loaded_weight.size() - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" - - param.data = loaded_weight.data - - -def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "lm_head.weight" in name: - # GPT-2 ties the weights of the embedding layer and the final - # linear layer. - continue - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if not name.startswith("transformer."): - name = "transformer." + name - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - # TODO: check megatron - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def _replace_name(megatron_name, name_mapping): - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name - - -def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ] - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - name = _replace_name(name, params_mapping) - if name.endswith(".bias") and name not in params_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ( - "input_layernorm", - "input_layernorm", - ), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ] - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - name = _replace_name(name, params_mapping) - if name.endswith(".bias") and name not in params_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def mistral_megatron_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: - # TODO: need to implement a general way to deal with prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { - ColumnParallelLinear: parallel_weight_loader, - MergedColumnParallelLinear: parallel_weight_loader, - QKVParallelLinear: parallel_weight_loader, - RowParallelLinear: parallel_weight_loader, - VocabParallelEmbedding: parallel_weight_loader, - ParallelLMHead: parallel_weight_loader, - # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights - # "default_weight_loader": default_weight_loader -} - -# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): -# # setattr(layer_class, 'megatron_weight_loader', weight_loader) -# layer_class.weight_loader = weight_loader - -__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { - "GPT2LMHeadModel": gpt2_weight_loader, - "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron - "LLaMAForCausalLM": llama_megatron_weight_loader, - "MistralForCausalLM": mistral_megatron_weight_loader, -} - - -# the actor model is .state_dict() -# Load megatron weights -def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): - weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) - weight_loader(actor_weights, vllm_model) - # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu - # after init, and we need this after sync model weights for in first iter. - vllm_model = vllm_model.cuda() - - -def _get_model_weight_loader(arch: str): - if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: - return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def update_megatron_weight_loader(): - for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): - layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py deleted file mode 100644 index 3dc2027bd..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader - -from typing import Dict, Optional, Union - -import torch -import torch.nn as nn -from transformers import PreTrainedModel -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - MultiModalConfig, - ParallelConfig, - SchedulerConfig, -) -from vllm.distributed.communication_op import tensor_model_parallel_all_gather -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.model_loader import BaseModelLoader -from vllm.model_executor.model_loader.loader import _initialize_model -from vllm.model_executor.model_loader.utils import set_default_torch_dtype - -from .config import LoadConfig, LoadFormat, ModelConfig -from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader -from .hf_weight_loader import update_hf_weight_loader -from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader - - -def get_model( - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - load_config: LoadConfig, - device_config: DeviceConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig = None, -) -> nn.Module: - loader = get_model_loader(load_config) - if load_config.load_format.startswith("dummy"): - return loader.load_model( - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) - else: - return loader.load_model( - actor_model=actor_model, - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - multimodal_config=multimodal_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) - - -def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: - """Get a model loader based on the load format.""" - - if isinstance(load_config.load_format, type): - return load_config.load_format(load_config) - - if load_config.load_format == LoadFormat.AUTO: - update_megatron_weight_loader() - return MegatronLoader(load_config) - - # NOTE(sgm): change the weight_loader function in runtime - if load_config.load_format == LoadFormat.MEGATRON: - update_megatron_weight_loader() - return MegatronLoader(load_config) - - if load_config.load_format == LoadFormat.HF: - update_hf_weight_loader() - return HFLoader(load_config) - - if load_config.load_format == LoadFormat.DTENSOR: - update_dtensor_weight_loader() - return DTensorLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_HF: - update_hf_weight_loader() - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_MEGATRON: - update_megatron_weight_loader() - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_DTENSOR: - update_dtensor_weight_loader() - return DummyModelLoader(load_config) - - raise ValueError("load format not supported in verl: {}, only support {} and {}".format(load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) - - -class DummyModelLoader(BaseModelLoader): - """Model loader that will set model weights to random values.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def load_model( - self, - *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - # initialize_dummy_weights(model) - return model.eval() - - -class MegatronLoader(BaseModelLoader): - """Model loader that can load the model weights from partitioned megatron model.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): - # NOTE(shengguangming) Load the weights from the actor model - pass - # if isinstance(actor_model, nn.Module): - # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - # else: - # load_weights(actor_weights=actor_model, vllm_model=model) - # return actor_model - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) - - # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm - if isinstance(actor_model, nn.Module): - load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - else: - load_megatron_weights(actor_weights=actor_model, vllm_model=model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -class HFLoader(BaseModelLoader): - """Model loader that can load the model weights from model's full params.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): - if isinstance(actor_model, Dict): - return actor_model.items() - elif isinstance(actor_model, nn.Module): - return dict(actor_model.named_parameters()).items() - else: - raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - # with torch.device(device_config.device): - # NOTE(sgm): init the model in cpu - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) - model.load_weights(self._get_weights_iterator(actor_model)) - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -class DTensorLoader(BaseModelLoader): - """Model loader that can load the model weights from partitioned megatron model.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): - # NOTE(shengguangming) Load the weights from the actor model - pass - # if isinstance(actor_model, nn.Module): - # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - # else: - # load_weights(actor_weights=actor_model, vllm_model=model) - # return actor_model - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) - - # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm - if isinstance(actor_model, nn.Module): - load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - else: - load_dtensor_weights(actor_weights=actor_model, vllm_model=model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 -# as they use ray, the _get_logits result will only need to return to the driver node, -# therefore gather is enough. However, we use SPMD instead of a central scheduler, -# all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, : self.org_vocab_size] - return logits - - -def logitsprocessor_init( - self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None, -) -> None: - """ - Args: - scale: A scaling factor to apply to the logits. - """ - super(LogitsProcessor, self).__init__() - self.scale = scale - self.vocab_size = vocab_size - # Whether the input is logits (default is hidden states). - self.logits_as_input = logits_as_input - # original vocabulary size (without LoRA). - self.org_vocab_size = org_vocab_size or vocab_size - # Soft cap the logits. Used in Gemma 2. - self.soft_cap = soft_cap - # Whether to use gather or all-gather to gather the logits. - self.use_gather = False - - -LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py b/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py deleted file mode 100644 index 83bfa1809..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py - -import warnings -from enum import IntEnum -from typing import Dict, Optional, Union - -import torch -import torch.nn as nn -import vllm.envs as envs -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - MultiModalConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, -) -from vllm.logger import init_logger -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor.models.interfaces import supports_lora, supports_vision -from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager -from vllm.utils import CudaMemoryProfiler, is_hip -from vllm.worker.model_runner import ModelRunner - -from .config import LoadConfig, ModelConfig -from .model_loader import get_model - -logger = init_logger(__name__) - - -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - -class ModelRunner(ModelRunner): - def __init__( - self, - model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, - return_hidden_states: bool = False, - ): - super().__init__( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config, - kv_cache_dtype, - is_driver_worker=True, # a hack - prompt_adapter_config=prompt_adapter_config, - multimodal_config=multimodal_config, - return_hidden_states=return_hidden_states, - ) - - # NOTE(sgm): add for verl - self.model = model # this will be replaced by get_model() - - # NOTE(sgm): initialize model using the actor model - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with CudaMemoryProfiler() as m: - self.model = get_model( - actor_model=self.model, - model_config=self.model_config, - device_config=self.device_config, - lora_config=self.lora_config, - load_config=self.load_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - multimodal_config=self.multimodal_config, - cache_config=self.cache_config, - ) - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) - - if self.lora_config: - assert supports_lora(self.model), "Model does not support LoRA" - assert not supports_vision(self.model), "To be tested: vision language model with LoRA settings." - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=self.model.config.max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - if self.prompt_adapter_config: - self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.device, - self.prompt_adapter_config, - ) - self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) - - if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is deprecated and will be removed. Please include kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2, - ) - self.model.load_kv_cache_scales(self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but model %s does not support loading scaling factors.", - self.model.__class__, - ) - else: - logger.warning("Using FP8 KV cache but no scaling factors provided. Defaulting to scaling factors of 1.0. This may lead to less accurate results!") - - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: - self.model = torch.compile(self.model, fullgraph=True, backend="eager") diff --git a/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py b/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py deleted file mode 100644 index d907e9a03..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""Model and data parallel groups.""" - -import os -from typing import Optional - -import torch -import torch.distributed -import vllm.distributed.parallel_state as ps -from vllm.distributed.parallel_state import ( - get_pp_group, - get_world_group, - init_distributed_environment, - init_model_parallel_group, -) -from vllm.logger import init_logger - -logger = init_logger(__name__) -""" -This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. -- We assume the Megatron tp+dp+pp world is already established before calling this function. - -""" - -# Device mesh for using DTensor -_DEVICE_MESH = None - -# Tensor model parallel group that the current rank belongs to. -_TP = None -# Pipeline model parallel group that the current rank belongs to. -_PP = None - - -# This method is for initializing the ParallelGroup when using HybridEngine -def initialize_parallel_state( - distributed_init_method: str = "env://", - backend: str = "nccl", - tensor_model_parallel_size: int = 1, - num_tp_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -): - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. - rank = int(os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - # Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) - if torch.distributed.get_world_size() > 1: - # NOTE: build a sepearate inference group with infer tp & micro dp - initialize_model_parallel_for_vllm( - tensor_model_parallel_size=tensor_model_parallel_size, - num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, - ) - else: - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - - -def ensure_model_parallel_initialized( - tensor_model_parallel_size: int, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """Helper to initialize model parallel groups if they are not initialized, - or ensure tensor-parallel and pipeline-parallel sizes are equal to expected - values if the model parallel groups are initialized. - """ - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - return - - assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" - pp_world_size = get_pp_group().world_size - assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" - - -# TODO(sgm): deviate from the v0.5.4, not pp now -def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" - return ps._TP is not None - # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) - - -def initialize_model_parallel_for_vllm( - tensor_model_parallel_size: int, - num_tensor_model_parallel_groups_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -) -> None: - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - - assert isinstance(tensor_model_parallel_size, int) - - # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group - # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group - - # Build the tensor model-parallel groups. - assert ps._TP is None, "tensor model parallel group is already initialized" - - global _TP - - world_size: int = torch.distributed.get_world_size() - - backend = torch.distributed.get_backend() - - num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - - if num_tensor_model_parallel_groups_per_train_tp == 1: - # if tensor_model_parallel_size == train_tensor_parallel_size: - # using the same tp group as Megatron/vllm - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine - else: - # initialize a micro_dp group and a tp group - # assume training tp=4, infer tp=2, then, weight is partitioned as - # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference - - # Build the inference tp groups - # train_tp = train_tensor_parallel_size - train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size - # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): - start = train_tp * i - end = train_tp * (i + 1) - for j in range(num_tensor_model_parallel_groups_per_train_tp): - ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) - for i in range(len(ranks)): - ranks[i] += j - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # Build the pipeline model-parallel groups. - # global _PIPELINE_MODEL_PARALLEL_GROUP - # global _PIPELINE_GLOBAL_RANKS - # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") - - # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() - # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP # for verl - - -def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """ - NOTE: This method is a hack from the open-sourced version without - asertion of world_size = tp * pp - - Initialize model parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model - parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model - parallelism. - - Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: - 4 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7] - 2 pipeline model-parallel groups: - [g0, g2, g4, g6], [g1, g3, g5, g7] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) - - # NOTE(sgm) we don't assert world_size == tp * pp - # DP is not managed by vllm but by the verl WorkerGroup - # if (world_size != - # tensor_model_parallel_size * pipeline_model_parallel_size): - # raise RuntimeError( - # f"world_size ({world_size}) is not equal to " - # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - global _TP - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) - - # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP # for verl - - -""" -Device mesh utilities -""" - - -def get_device_mesh(): - assert _DEVICE_MESH is not None, "device mesh is not initialized" - return _DEVICE_MESH - - -""" -Tensor model parallel utilities -""" - - -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - assert _TP is not None, "tensor model parallel group is not initialized" - return _TP.device_group - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py deleted file mode 100644 index 3a8ba25c0..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py - -import os -import socket -from typing import Iterable, List, Optional, Set, Tuple - -import torch -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - MultiModalConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput - -from .config import LoadConfig, ModelConfig - -logger = init_logger(__name__) - - -class SPMDGPUExecutor(ExecutorBase): - """SPMD-based multi-GPU executor implementations.""" - - def __init__( - self, - model, # pytorch model itself or its parameter dict - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.multimodal_config = multimodal_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - - distributed_init_method = initialize_cluster(parallel_config) - self._init_executor(model, distributed_init_method) - - # TODO(sgm): verl not support speculative decode now - def _init_executor(self, model, distributed_init_method) -> None: - assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." - - # Create the parallel worker for each GPU. - self._init_workers_sp(model, distributed_init_method) - - def _init_workers_sp(self, model, distributed_init_method: str): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker - - rank = int(os.getenv("RANK")) - local_rank = int(os.getenv("LOCAL_RANK")) - print(f"local rank {local_rank}") - - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ["NCCL_CUMEM_ENABLE"] = "0" - - self.worker = Worker( - model, - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.cache_config, - self.load_config, - local_rank, - rank, - distributed_init_method, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, - speculative_config=None, - prompt_adapter_config=self.speculative_config, - is_driver_worker=True, - model_runner_cls=None, # use the default one - ) - - # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() - self.worker.init_device() - self.worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self.worker.determine_num_available_blocks() - - # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will - # have its own scheduler - num_gpu_blocks = num_blocks[0] - num_cpu_blocks = num_blocks[1] - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers.""" - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - if torch.distributed.get_rank() == 0: - print(f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") - self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - if torch.distributed.get_rank() == 0: - print(f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") - - # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache - def init_cache_engine(self) -> None: - self.worker._init_cache_engine() - - def free_cache_engine(self) -> None: - self.worker.free_cache_engine() - - def execute_model(self, execute_model_req) -> List[SamplerOutput]: - all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) - - # NOTE(sgm): - # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs - # In vllm with ray, only the driver worker returns the sampling results. - return all_outputs - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self.worker.add_lora(lora_request=lora_request) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.worker.remove_lora(lora_id=lora_id) - - def list_loras(self) -> Set[int]: - return self.worker.list_loras() - - def check_health(self) -> None: - # SPMDExecutor will always be healthy as long as - # it's running. - return - - # NOTE(sgm) add for verl to pass the abstract class test, not used - from vllm.prompt_adapter.request import PromptAdapterRequest - - def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.add_prompt_adapter(prompt_adapter_request) - - def list_prompt_adapters(self) -> Set[int]: - return self.worker.list_prompt_adapters() - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.worker.pin_lora(lora_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.pin_prompt_adapter(prompt_adapter_id) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.remove_prompt_adapter(prompt_adapter_id) - - # NOTE(sgm): add for verl - def offload_model_weights(self) -> None: - self.worker.offload_model_weights() - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - -def initialize_cluster( - parallel_config: ParallelConfig, - engine_use_ray: bool = False, - ray_address: Optional[str] = None, -) -> Tuple[str, Optional[None]]: - """Initialize the distributed cluster probably with Ray. - - Args: - parallel_config: The configurations for parallel execution. - - Returns: - The `distributed_init_method` is the address for initializing the - distributed backend. - """ - - # Initialize cluster locally. - # We need to setup the distributed init method to make sure - # the distributed megatron code (e.g., get world size) works correctly. - # distributed_init_method = f"tcp://localhost:{port}" - distributed_init_method = "env://" - return distributed_init_method - - -def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -# TODO(sgm): not implemented async executor yet -class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): - async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Executes one model step on the given sequences.""" - raise NotImplementedError - - async def check_health_async(self) -> None: - """Checks if the executor is healthy. If not, it should raise an - exception.""" - self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py b/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py deleted file mode 100644 index 2541bd157..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py - -from typing import List, Optional - -from transformers import PreTrainedTokenizer -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import * -from vllm.utils import LRUCache - - -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int]): - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = tokenizer - self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None - - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - return True - - def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self.max_input_length - - def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) - - async def encode_async(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) - - def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - # TODO(sgm): the lora tokenizer is also passed, but may be different - tokenizer = self.tokenizer - # tokenizer = (get_lora_tokenizer( - # lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - # FIXME(sgm): for simplicity, we assign the special token here - @property - def pad_token_id(self): - return self.tokenizer.pad_token_id - - @property - def eos_token_id(self): - return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_5_4/worker.py b/verl/third_party/vllm/vllm_v_0_5_4/worker.py deleted file mode 100644 index 302145aa1..000000000 --- a/verl/third_party/vllm/vllm_v_0_5_4/worker.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py -"""A GPU worker class.""" - -import gc -import os -from typing import Dict, List, Optional, Tuple, Type, Union - -import torch -import torch.distributed -import torch.nn as nn -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - MultiModalConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) - -# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state -from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce -from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, IntermediateTensors, SamplerOutput -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase -from vllm.worker.model_runner_base import ModelRunnerInputBase -from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype -from vllm.worker.worker_base import WorkerInput - -from .config import LoadConfig, LoadFormat, ModelConfig -from .dtensor_weight_loaders import load_dtensor_weights -from .hf_weight_loader import load_hf_weights -from .megatron_weight_loaders import load_megatron_weights -from .model_runner import ModelRunner -from .parallel_state import ensure_model_parallel_initialized - - -class Worker(Worker): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. - """ - - def __init__( - self, - model: Union[nn.Module, Dict], # model itself or its parameter dict - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - ) -> None: - # self.model = model # will be replaced in the init_model - self.model_config = model_config - self.parallel_config = parallel_config - self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.is_driver_worker = is_driver_worker # TODO: we don't need driver - # if parallel_config and is_driver_worker: - # assert rank % parallel_config.tensor_parallel_size == 0, \ - # "Driver worker should be rank 0 of tensor parallel group." - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - - init_cached_hf_modules() - self.multimodal_config = multimodal_config - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_args = {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {"return_hidden_states": True} - - # TODO(sgm): set correct model runner class - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif self.model_config.embedding_mode: - ModelRunnerClass = EmbeddingModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model, # [VERL]: add for verl - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - multimodal_config=multimodal_config, - **speculative_args, - ) - - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] = None - # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - - # NOTE(sgm): [VERL] For offloading inference engine params - self.cpu_model = None - - def init_device(self) -> None: - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. - self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - self.device = torch.device(f"cuda:{local_rank}") - if self.rank < 0: - raise ValueError("Invalid or unspecified rank.") - torch.cuda.set_device(self.device) - - # Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - self.parallel_config.world_size = world_size - - _check_if_gpu_supports_dtype(self.model_config.dtype) - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError(f"Not support device type: {self.device_config.device}") - - # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) - # Set random seed. - set_random_seed(self.model_config.seed) - # self.model = get_model(actor_model=self.model, model_config=self.model_config) - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - # torch.cuda.reset_peak_memory_stats() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - peak_memory = total_gpu_memory - free_gpu_memory - - assert peak_memory > 0, "Error in memory profiling. This happens when the GPU memory was not properly cleaned up before initializing the vLLM instance." - - cache_block_size = self.get_cache_block_size_bytes() - - # NOTE(sgm) [VERL] use the remaining memory - num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) - # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) - - num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - - # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank - num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") - num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") - - torch.distributed.all_reduce(num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) - torch.distributed.all_reduce(num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) - num_gpu_blocks = num_gpu_blocks.item() - num_cpu_blocks = num_cpu_blocks.item() - gc.collect() - torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks - - def _init_cache_engine(self): - if self.cache_engine is None and self.gpu_cache is None: - super()._init_cache_engine() - - def free_cache_engine(self): - # ensure `enforce_eager=True` - self.cache_engine = None - self.gpu_cache = None - - # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() - def execute_model(self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" - worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(execute_model_req.seq_group_metadata_list) - - # verl.worker.workerbase.WorkerBase - # swap cache - super().execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - return self.model_runner.execute_model( - model_input, - self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, - intermediate_tensors, - ) - - # assume the input is .state_dict() - def sync_model_weights(self, actor_weights: Dict, load_format: str): - if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: - load_megatron_weights(actor_weights, self.model_runner.model) - elif load_format == LoadFormat.HF: - # full model state dict without no sharding - load_hf_weights(actor_weights, self.model_runner.model) - elif load_format == LoadFormat.DTENSOR: - load_dtensor_weights(actor_weights, self.model_runner.model) - - def offload_model_weights(self) -> None: - if self.cpu_model is None: - self.cpu_model = {} - for name, params in self.model_runner.model.named_parameters(): - self.cpu_model[name] = torch.empty_like(params, device="cpu") - params.data = self.cpu_model[name] - else: - for name, params in self.model_runner.model.named_parameters(): - params.data = self.cpu_model[name] - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = "env://", - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron - init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - - ensure_model_parallel_initialized( - tensor_model_parallel_size=parallel_config.tensor_parallel_size, - pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, - ) - - # TODO(sgm): check whether need this - # if pynccl_utils.is_initialized(): - # pynccl_world_size = pynccl_utils.get_world_size() - # if pynccl_world_size != parallel_config.world_size: - # raise RuntimeError( - # "pynccl is already initialized but the pynccl world " - # "size does not match parallel_config.world_size " - # f"({pynccl_world_size} vs. {parallel_config.world_size}).") - # elif parallel_config.world_size > 1: - # # NOTE(woosuk): We don't initialize pynccl process group when world size - # # is 1. - # # NOTE(kaichao): By default, pynccl is initialized for tp group. - # pynccl_utils.init_process_group( - # group=get_tensor_model_parallel_cpu_group()) - - # # Initialize a custom fast all-reduce implementation. - # if not parallel_config.disable_custom_all_reduce: - # init_custom_ar() - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - # if pynccl_utils.is_initialized(): - # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py deleted file mode 100644 index bc4685c5f..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py - -import os -from dataclasses import dataclass - -from transformers import PretrainedConfig -from vllm.config import EngineConfig -from vllm.engine.arg_utils import EngineArgs - -from .config import LoadConfig, ModelConfig - - -@dataclass -class EngineArgs(EngineArgs): - model_hf_config: PretrainedConfig = None # for verl - - def __post_init__(self): - pass - - def create_model_config(self) -> ModelConfig: - return ModelConfig( - hf_config=self.model_hf_config, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - code_revision=self.code_revision, - rope_scaling=self.rope_scaling, - rope_theta=self.rope_theta, - tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, - quantization=self.quantization, - quantization_param_path=self.quantization_param_path, - enforce_eager=self.enforce_eager, - max_context_len_to_capture=self.max_context_len_to_capture, - max_seq_len_to_capture=self.max_seq_len_to_capture, - max_logprobs=self.max_logprobs, - disable_sliding_window=self.disable_sliding_window, - skip_tokenizer_init=self.skip_tokenizer_init, - served_model_name=self.served_model_name, - limit_mm_per_prompt=self.limit_mm_per_prompt, - use_async_output_proc=not self.disable_async_output_proc, - override_neuron_config=self.override_neuron_config, - config_format=self.config_format, - mm_processor_kwargs=self.mm_processor_kwargs, - ) - - def create_load_config(self) -> LoadConfig: - return LoadConfig( - load_format=self.load_format, - download_dir=self.download_dir, - model_loader_extra_config=self.model_loader_extra_config, - ignore_patterns=self.ignore_patterns, - ) - - def create_engine_config(self) -> EngineConfig: - engine_config = super().create_engine_config() - - # NOTE[VERL]: Use the world_size set by torchrun - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - engine_config.parallel_config.world_size = world_size - - return engine_config diff --git a/verl/third_party/vllm/vllm_v_0_6_3/config.py b/verl/third_party/vllm/vllm_v_0_6_3/config.py deleted file mode 100644 index fcac585f9..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/config.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py - -import enum -import json -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Union - -from transformers import PretrainedConfig - -# Add for verl -from vllm.config import ModelConfig -from vllm.logger import init_logger -from vllm.utils import is_hip - -if TYPE_CHECKING: - from vllm.model_executor.model_loader.loader import BaseModelLoader - -logger = init_logger(__name__) - - -class LoadFormat(str, enum.Enum): - AUTO = "auto" - MEGATRON = "megatron" - HF = "hf" - DTENSOR = "dtensor" - DUMMY_HF = "dummy_hf" - DUMMY_MEGATRON = "dummy_megatron" - DUMMY_DTENSOR = "dummy_dtensor" - - -class ModelConfig(ModelConfig): - def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: - super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) # noqa: B026 - self.hf_config = hf_config - - -@dataclass -class LoadConfig: - """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. - "bitsandbytes" will load nf4 type weights. - ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's - checkpoints. - - """ - - load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO - download_dir: Optional[str] = None - model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) - ignore_patterns: Optional[Union[List[str], str]] = None - - def __post_init__(self): - model_loader_extra_config = self.model_loader_extra_config or {} - if isinstance(model_loader_extra_config, str): - self.model_loader_extra_config = json.loads(model_loader_extra_config) - self._verify_load_format() - - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - def _verify_load_format(self) -> None: - if not isinstance(self.load_format, str): - return - - load_format = self.load_format.lower() - self.load_format = LoadFormat(load_format) - - rocm_not_supported_load_format: List[str] = [] - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)] - raise ValueError(f"load format '{load_format}' is not supported in ROCm. Supported load formats are {rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py deleted file mode 100644 index 98c2db1f6..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader - -from typing import Dict - -import torch.nn as nn -from torch.distributed._tensor import DTensor -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import is_pp_missing_parameter - - -def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - stacked_name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if stacked_name.endswith(".bias") and stacked_name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[stacked_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # lm_head is not used in vllm as it is tied with embed_token. - # To prevent errors, skip loading lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "lm_head.weight" in name: - continue - if ".attn.bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight) - - -def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=vllm_model.config.n_routed_experts, - ) - - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if ("mlp.experts." in name) and name not in params_dict: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader( - param, - local_loaded_weight.to(dtype=param.dtype), - weight_name, - shard_id=shard_id, - expert_id=expert_id, - ) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, vllm_model): - continue - - param = params_dict[name] - local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - - -def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - pass - - -def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): - param_name = _process_parameter_names(name=param_name) - if parallelize_plan is not None: - assert param_name in parallelize_plan, f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" - placement = parallelize_plan[param_name] - local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, placements=placement).to_local() - else: - local_loaded_weights = loaded_weights.full_tensor() - return local_loaded_weights - - -def _process_parameter_names(name): - # Remove '.weight' if it exists at the end of the string - if name.endswith(".weight"): - name = name[:-7] - - # Remove 'model.layers.x.' or 'model.' prefix - if "model.layers" in name: - parts = name.split(".") - # Reconstruct the string without 'model.layers.x.' - name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' - elif name.startswith("model."): - name = name[6:] # Remove 'model.' - - return name - - -__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { - "GPT2LMHeadModel": gpt2_dtensor_weight_loader, - "LlamaForCausalLM": llama_dtensor_weight_loader, - "LLaMAForCausalLM": llama_dtensor_weight_loader, - "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM - "InternLMForCausalLM": llama_dtensor_weight_loader, - "AquilaModel": llama_dtensor_weight_loader, - "AquilaForCausalLM": llama_dtensor_weight_loader, - "Phi3ForCausalLM": llama_dtensor_weight_loader, - "GemmaForCausalLM": gemma_dtensor_weight_loader, - "Gemma2ForCausalLM": gemma_dtensor_weight_loader, - "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, - "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, - "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, - "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, - "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, -} - - -# the actor model is .state_dict() -# Load dtensor weights -def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): - weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) - weight_loader(actor_weights, vllm_model) - # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu - # after init, and we need this after sync model weights for in first iter. - vllm_model = vllm_model.cuda() - - -def _get_model_weight_loader(arch: str): - if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: - return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") - - -# NOTE(sgm): we use per-parameter weight loader in each vllm sub -def update_dtensor_weight_loader(): - pass diff --git a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py deleted file mode 100644 index 23304298b..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader - -from typing import Dict - -import torch.nn as nn -from vllm.model_executor.model_loader.utils import set_default_torch_dtype - - -def update_hf_weight_loader(): - print("no hf weight loader need to be updated") - return - - -def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): - assert isinstance(actor_weights, Dict) - with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: - del actor_weights["lm_head.weight"] - vllm_model.load_weights(actor_weights.items()) - for _, module in vllm_model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py deleted file mode 100644 index 964107b12..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/llm.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py - -from typing import Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence -from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast -from vllm import LLM -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.utils import Counter - -from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer - -from .arg_utils import EngineArgs -from .llm_engine_sp import LLMEngine - - -class LLM(LLM): - """An LLM for generating texts from given prompts and sampling parameters. - - This class includes a tokenizer, a language model (possibly distributed - across multiple GPUs), and GPU memory space allocated for intermediate - states (aka KV cache). Given a batch of prompts and sampling parameters, - this class generates texts from the model, using an intelligent batching - mechanism and efficient memory management. - - NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. - - Args: - model: A HuggingFace Transformers model instance. - tokenizer: A HuggingFace Transformers tokenizer instance. - tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer - if available, and "slow" will always use the slow tokenizer. - trust_remote_code: Trust remote code (e.g., from HuggingFace) when - downloading the model and tokenizer. - tensor_parallel_size: The number of GPUs to use for distributed - execution with tensor parallelism. - dtype: The data type for the model weights and activations. Currently, - we support `float32`, `float16`, and `bfloat16`. If `auto`, we use - the `torch_dtype` attribute specified in the model config file. - However, if the `torch_dtype` in the config is `float32`, we will - use `float16` instead. - quantization: The method used to quantize the model weights. Currently, - we support "awq". If None, we assume the model weights are not - quantized and use `dtype` to determine the data type of the weights. - revision: The specific model version to use. It can be a branch name, - a tag name, or a commit id. - tokenizer_revision: The specific tokenizer version to use. It can be a - branch name, a tag name, or a commit id. - seed: The seed to initialize the random number generator for sampling. - gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to - reserve for the model weights, activations, and KV cache. Higher - values will increase the KV cache size and thus improve the model's - throughput. However, if the value is too high, it may cause out-of- - memory (OOM) errors. - swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - This can be used for temporarily storing the states of the requests - when their `best_of` sampling parameters are larger than 1. If all - requests will have `best_of=1`, you can safely set this to 0. - Otherwise, too small values may cause out-of-memory (OOM) errors. - enforce_eager: Whether to enforce eager execution. If True, we will - disable CUDA graph and always execute the model in eager mode. - If False, we will use CUDA graph and eager execution in hybrid. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. - disable_custom_all_reduce: See ParallelConfig - """ - - def __init__( - self, - model: Union[nn.Module, Dict], # model itself or its parameter dict - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], - model_hf_config: PretrainedConfig, - tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - skip_tokenizer_init: bool = False, - tensor_parallel_size: int = 1, - dtype: str = "auto", - quantization: Optional[str] = None, - revision: Optional[str] = None, - tokenizer_revision: Optional[str] = None, - seed: int = 0, - gpu_memory_utilization: float = 0.9, - swap_space: int = 32, - cpu_offload_gb: float = 0, - enforce_eager: bool = False, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = False, - load_format="auto", - **kwargs, - ) -> None: - if "disable_log_stats" not in kwargs: - kwargs["disable_log_stats"] = True - removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type") - if any(k in kwargs for k in removed_vision_keys): - raise TypeError("There is no need to pass vision-related arguments anymore.") - engine_args = EngineArgs( - model_hf_config=model_hf_config, - # tokenizer=tokenizer, - tokenizer_mode=tokenizer_mode, - skip_tokenizer_init=skip_tokenizer_init, - trust_remote_code=trust_remote_code, - tensor_parallel_size=tensor_parallel_size, - dtype=dtype, - quantization=quantization, - revision=revision, - tokenizer_revision=tokenizer_revision, - seed=seed, - gpu_memory_utilization=gpu_memory_utilization, - swap_space=swap_space, - cpu_offload_gb=cpu_offload_gb, - enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, - max_seq_len_to_capture=max_seq_len_to_capture, - disable_custom_all_reduce=disable_custom_all_reduce, - load_format=load_format, - **kwargs, - ) - tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) - if not isinstance(tokenizer, tokenizer_cls): - raise ValueError(f"Unexpected tokenizer type: {type(tokenizer)}. Must beone of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer") - self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext - self.request_counter = Counter() - - def init_cache_engine(self): - self.llm_engine.init_cache_engine() - - def free_cache_engine(self): - self.llm_engine.free_cache_engine() - - def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_engine.tokenizer - - def set_tokenizer( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - ) -> None: - self.llm_engine.tokenizer = tokenizer - - def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: - outputs = super()._run_engine(use_tqdm=use_tqdm) - return self._post_process_outputs(outputs) - - # # NOTE(shengguangming): add for verl - # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. - # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: - # # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - # token_ids = prompt_token_ids[non_pad_index:].tolist() - # return token_ids - - # NOTE(shengguangming): add for verl - def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: - output_token_ids = [] - logprobs = [] - for request_output in request_outputs: # List[RequestOutput] - outputs = request_output.outputs - for output in outputs: # List[CompletionOutput], usually len == 1 - output_token_ids.append(torch.tensor(output.token_ids)) - # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits - logprobs_dicts = output.logprobs - if logprobs_dicts is not None: - logprob = [] - for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): - logprob.append(logprobs_dict[id].logprob) - logprobs.append(torch.tensor(logprob)) - - pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) - if len(logprobs) > 0: - logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) - return output_token_ids, logprobs - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - def offload_model_weights(self) -> None: - self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py deleted file mode 100644 index 57ed22968..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py +++ /dev/null @@ -1,429 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py - -from functools import partial -from typing import Callable, Dict, Iterable, Optional, Type, Union - -import torch.nn as nn -from vllm.config import ( - CacheConfig, - DecodingConfig, - DeviceConfig, - EngineConfig, - LoRAConfig, - ObservabilityConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) -from vllm.core.scheduler import Scheduler -from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict -from vllm.engine.metrics_types import StatLoggerBase -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.sequence import Sequence -from vllm.tracing import init_tracer -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message -from vllm.utils import Counter, weak_bind -from vllm.version import __version__ as VLLM_VERSION - -from .arg_utils import EngineArgs -from .config import LoadConfig, ModelConfig -from .tokenizer import TokenizerGroup - -logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - - -class LLMEngine(LLMEngine): - """An LLM engine that receives requests and generates texts. - - This is the main class for the vLLM engine. It receives requests - from clients and generates texts from the LLM. It includes a tokenizer, a - language model (possibly distributed across multiple GPUs), and GPU memory - space allocated for intermediate states (aka KV cache). This class utilizes - iteration-level scheduling and efficient memory management to maximize the - serving throughput. - - The :class:`~vllm.LLM` class wraps this class for offline batched inference - and the :class:`AsyncLLMEngine` class wraps this class for online serving. - - The config arguments are derived from :class:`~vllm.EngineArgs`. (See - :ref:`engine_args`) - - Args: - model_config: The configuration related to the LLM model. - cache_config: The configuration related to the KV cache memory - management. - parallel_config: The configuration related to distributed execution. - scheduler_config: The configuration related to the request scheduler. - device_config: The configuration related to the device. - lora_config (Optional): The configuration related to serving multi-LoRA. - speculative_config (Optional): The configuration related to speculative - decoding. - executor_class: The model executor class for managing distributed - execution. - prompt_adapter_config (Optional): The configuration related to serving - prompt adapters. - log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection. - """ - - def __init__( - self, - # NOTE(sgm): first two arguments are added for verl - model: Union[nn.Module, Dict], # model itself or its parameter dict - tokenizer: nn.Module, - # NOTE(sgm): vllm original arguments - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - executor_class: Type[ExecutorBase], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - input_registry: InputRegistry = INPUT_REGISTRY, - use_cached_outputs: bool = False, - ) -> None: - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, " - "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.use_v2_block_manager, - scheduler_config.num_scheduler_steps, - scheduler_config.chunked_prefill_enabled, - scheduler_config.multi_step_stream_outputs, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - use_cached_outputs, - model_config.mm_processor_kwargs, - ) - # TODO(woosuk): Print more configs in debug mode. - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig() - self.log_stats = log_stats - self.use_cached_outputs = use_cached_outputs - - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer(tokenizer) - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: - self.tokenizer = None - self.detokenizer = None - tokenizer_group = None - - # Ensure that the function doesn't contain a reference to self, - # to avoid engine GC issues - def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, "tokenizer_group cannot be None, make sure skip_tokenizer_init is False" - return tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict(model_config) - - self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer) - - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor(model_config) - - self.model_executor = executor_class( - model=model, # add for spmd_gpu_executor - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) - - if not self.model_config.embedding_mode: - self._initialize_kv_caches() - - # If usage stat is enabled, collect relevant info. - if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import get_architecture_class_name - - usage_message.report_usage( - get_architecture_class_name(model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": str(model_config.dtype), - "tensor_parallel_size": parallel_config.tensor_parallel_size, - "block_size": cache_config.block_size, - "gpu_memory_utilization": cache_config.gpu_memory_utilization, - # Quantization - "quantization": model_config.quantization, - "kv_cache_dtype": str(cache_config.cache_dtype), - # Feature flags - "enable_lora": bool(lora_config), - "enable_prompt_adapter": bool(prompt_adapter_config), - "enable_prefix_caching": cache_config.enable_prefix_caching, - "enforce_eager": model_config.enforce_eager, - "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, - }, - ) - - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - - self.cached_scheduler_outputs = [SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size)] - - self.scheduler_contexts = [SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size)] - - if model_config.use_async_output_proc: - process_model_outputs = weak_bind(self._process_model_outputs) - - self.async_callbacks = [partial(process_model_outputs, ctx=self.scheduler_contexts[v_id]) for v_id in range(self.parallel_config.pipeline_parallel_size)] - else: - self.async_callbacks = [] - - # Currently used by AsyncLLMEngine to ensure quick append - # of request outputs to asyncio queues - self.process_request_outputs_callback: Optional[Callable] = None - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = [ - Scheduler( - scheduler_config, - cache_config, - lora_config, - parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] if model_config.use_async_output_proc else None, - ) - for v_id in range(parallel_config.pipeline_parallel_size) - ] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger - - self.stat_loggers = { - "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), - "prometheus": PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len, - ), - } - self.stat_loggers["prometheus"].info("cache_config", self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. - self.output_processor = SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - ), - ) - - # NOTE: added by Reasoning360 - self.max_page_usage = 0 - self.page_usage_average = 0 - self.page_usage_sample_times = 0 - - # TODO(sgm): add for verl but we may not tokenizer in Rollout - def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): - init_kwargs = dict(enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None) - init_kwargs.update(tokenizer_init_kwargs) - return TokenizerGroup(tokenizer, **init_kwargs) - - def init_cache_engine(self): - # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache - # Re-capture CUDAGraph would be time-consuming - self.model_executor.init_cache_engine() - - def free_cache_engine(self): - self.model_executor.free_cache_engine() - - # NOTE(sgm): currently, we only support GPU executor - # The GPUExecutor remove the Ray dependency - @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: - # Initialize the cluster and specify the executor class.] - assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" - - # print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client() - if engine_config.parallel_config.world_size == 1: - engine_config.load_config.load_format = "dummy_hf" - - from .spmd_gpu_executor import SPMDGPUExecutor - - executor_class = SPMDGPUExecutor - - return executor_class - - @classmethod - def from_engine_args( - cls, - model, - tokenizer, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Initialize the cluster and specify the executor class. - assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" - - from .spmd_gpu_executor import SPMDGPUExecutor - - executor_class = SPMDGPUExecutor - - # Create the LLM engine. - engine = cls( - model, - tokenizer, - **engine_config.to_dict(), - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - def offload_model_weights(self) -> None: - self.model_executor.offload_model_weights() - - def step(self): - ret = super().step() - # # NOTE: added by Reasoning360. Log stats for page usage. - num_total_gpu = self.cache_config.num_gpu_blocks - gpu_cache_usage_sys = 0. - if num_total_gpu is not None: - num_free_gpu = sum( - scheduler.block_manager.get_num_free_gpu_blocks() - for scheduler in self.scheduler) - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) - self.max_page_usage = max(self.max_page_usage, gpu_cache_usage_sys) - self.page_usage_average = ( - (self.page_usage_average * self.page_usage_sample_times + gpu_cache_usage_sys) / - (self.page_usage_sample_times + 1)) - self.page_usage_sample_times += 1 - return ret - - def report_page_usage_history(self, reset=False): - # NOTE: added by Reasoning360 - if reset: - max_page_usage = self.max_page_usage - page_usage_average = self.page_usage_average - self.max_page_usage = 0 - self.page_usage_average = 0 - self.page_usage_sample_times = 0 - return { - "gpu_max_page_usage": max_page_usage, - "gpu_average_page_usage": page_usage_average, - } - return { - "gpu_max_page_usage": self.max_page_usage, - "gpu_average_page_usage": self.page_usage_average, - } diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py deleted file mode 100644 index 127a0ff49..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader - -from typing import Dict, Iterable - -import torch -import torch.nn as nn -from vllm.model_executor.layers.linear import * -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding -from vllm.model_executor.models import ModelRegistry - - -# NOTE(shengguangming): replace the origin weight loader function in the class -def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Parallel Linear weight loader.""" - assert param.size() == loaded_weight.size(), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format(param.size(), loaded_weight.size()) - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" - - param.data = loaded_weight.data - - -def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Default weight loader.""" - assert param.size() == loaded_weight.size() - assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" - - param.data = loaded_weight.data - - -def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) - for name, loaded_weight in actor_weights.items(): - if "lm_head.weight" in name: - # GPT-2 ties the weights of the embedding layer and the final - # linear layer. - continue - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if not name.startswith("transformer."): - name = "transformer." + name - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - # TODO: check megatron - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def qwen2_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - if "rotary_emb.inv_freq" in name: - continue - if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ] - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - name = _replace_name(name, params_mapping) - if name.endswith(".bias") and name not in params_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ( - "input_layernorm", - "input_layernorm", - ), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ] - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, loaded_weight in actor_weights.items(): - name = _replace_name(name, params_mapping) - if name.endswith(".bias") and name not in params_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - - -def _replace_name(megatron_name, name_mapping): - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name - - -def mistral_megatron_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: - # TODO: need to implement a general way to deal with prefix - params_dict = dict(vllm_model.named_parameters()) - for name, weight in actor_weights: - if "rotary_emb.inv_freq" in name: - continue - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, weight) - - -def megatron_core_te_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: - # NOTE(shengguangming): the megatron llama may have this prefix - params_dict = dict(vllm_model.named_parameters()) - for name, weight in actor_weights: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, weight) - - -__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { - ColumnParallelLinear: parallel_weight_loader, - MergedColumnParallelLinear: parallel_weight_loader, - QKVParallelLinear: parallel_weight_loader, - RowParallelLinear: parallel_weight_loader, - VocabParallelEmbedding: parallel_weight_loader, - ParallelLMHead: parallel_weight_loader, - # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights - # "default_weight_loader": default_weight_loader -} - -# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): -# # setattr(layer_class, 'megatron_weight_loader', weight_loader) -# layer_class.weight_loader = weight_loader - -__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { - "GPT2LMHeadModel": gpt2_weight_loader, - "LlamaForCausalLM": megatron_core_te_weight_loader, # use te backend for open-source megatron - "LLaMAForCausalLM": megatron_core_te_weight_loader, - "MistralForCausalLM": mistral_megatron_weight_loader, - "Qwen2ForCausalLM": megatron_core_te_weight_loader, -} - - -# the actor model is .state_dict() -# Load megatron weights -def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): - weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) - weight_loader(actor_weights, vllm_model) - # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu - # after init, and we need this after sync model weights for in first iter. - vllm_model = vllm_model.cuda() - - -def _get_model_weight_loader(arch: str): - if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: - return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] - raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def update_megatron_weight_loader(): - for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): - layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py deleted file mode 100644 index b9771aacb..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models -"""Utilities for selecting and loading models.""" - -from typing import Dict, Optional, Union - -import torch -import torch.nn as nn -from transformers import PreTrainedModel -from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig -from vllm.distributed.communication_op import tensor_model_parallel_all_gather -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.model_loader import BaseModelLoader -from vllm.model_executor.model_loader.loader import _initialize_model -from vllm.model_executor.model_loader.utils import set_default_torch_dtype - -from .config import LoadConfig, LoadFormat, ModelConfig -from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader -from .hf_weight_loader import update_hf_weight_loader -from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader - - -def get_model( - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - load_config: LoadConfig, - device_config: DeviceConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig = None, -) -> nn.Module: - loader = get_model_loader(load_config) - if load_config.load_format.startswith("dummy"): - return loader.load_model( - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) - else: - return loader.load_model( - actor_model=actor_model, - model_config=model_config, - device_config=device_config, - lora_config=lora_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config, - ) - - -def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: - """Get a model loader based on the load format.""" - - if isinstance(load_config.load_format, type): - return load_config.load_format(load_config) - - if load_config.load_format == LoadFormat.AUTO: - update_megatron_weight_loader() - return MegatronLoader(load_config) - - # NOTE(sgm): change the weight_loader function in runtime - if load_config.load_format == LoadFormat.MEGATRON: - update_megatron_weight_loader() - return MegatronLoader(load_config) - - if load_config.load_format == LoadFormat.HF: - update_hf_weight_loader() - return HFLoader(load_config) - - if load_config.load_format == LoadFormat.DTENSOR: - update_dtensor_weight_loader() - return DTensorLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_HF: - update_hf_weight_loader() - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_MEGATRON: - update_megatron_weight_loader() - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.DUMMY_DTENSOR: - update_dtensor_weight_loader() - return DummyModelLoader(load_config) - - raise ValueError("load format not supported in verl: {}, only support {} and {}".format(load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) - - -class DummyModelLoader(BaseModelLoader): - """Model loader that will set model weights to random values.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass - - def load_model( - self, - *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - # initialize_dummy_weights(model) - return model.eval() - - -class MegatronLoader(BaseModelLoader): - """Model loader that can load the model weights from partitioned megatron model.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass # Nothing to download - - def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): - # NOTE(shengguangming) Load the weights from the actor model - pass - # if isinstance(actor_model, nn.Module): - # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - # else: - # load_weights(actor_weights=actor_model, vllm_model=model) - # return actor_model - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - - # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm - if isinstance(actor_model, nn.Module): - load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - else: - load_megatron_weights(actor_weights=actor_model, vllm_model=model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -class HFLoader(BaseModelLoader): - """Model loader that can load the model weights from model's full params.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass # Nothing to download - - def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): - if isinstance(actor_model, Dict): - return actor_model.items() - elif isinstance(actor_model, nn.Module): - return dict(actor_model.named_parameters()).items() - else: - raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - # with torch.device(device_config.device): - # NOTE(sgm): init the model in cpu - model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - model.load_weights(self._get_weights_iterator(actor_model)) - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -class DTensorLoader(BaseModelLoader): - """Model loader that can load the model weights from partitioned megatron model.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass # Nothing to download - - def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): - # NOTE(shengguangming) Load the weights from the actor model - pass - # if isinstance(actor_model, nn.Module): - # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - # else: - # load_weights(actor_weights=actor_model, vllm_model=model) - # return actor_model - - def load_model( - self, - actor_model: Union[PreTrainedModel, Dict], - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - - # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm - if isinstance(actor_model, nn.Module): - load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) - else: - load_dtensor_weights(actor_weights=actor_model, vllm_model=model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() - # NOTE(sgm) Some weights are point to gpu, but still need this. - model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage - return model.eval() - - -# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 -# as they use ray, the _get_logits result will only need to return to the driver node, -# therefore gather is enough. However, we use SPMD instead of a central scheduler, -# all_gather is required (aligned with v0.2.6) -def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, : self.org_vocab_size] - return logits - - -def logitsprocessor_init( - self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None, -) -> None: - """ - Args: - scale: A scaling factor to apply to the logits. - """ - super(LogitsProcessor, self).__init__() - self.scale = scale - self.vocab_size = vocab_size - # Whether the input is logits (default is hidden states). - self.logits_as_input = logits_as_input - # original vocabulary size (without LoRA). - self.org_vocab_size = org_vocab_size or vocab_size - # Soft cap the logits. Used in Gemma 2. - self.soft_cap = soft_cap - # Whether to use gather or all-gather to gather the logits. - self.use_gather = False - - -LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py deleted file mode 100644 index 6507d3c6d..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py - -import warnings -from enum import IntEnum -from typing import Dict, Optional, Union - -import torch -import torch.nn as nn -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - ObservabilityConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, -) -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor.models.interfaces import supports_lora -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager -from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo -from vllm.worker.model_runner import ModelRunner - -from .config import LoadConfig, ModelConfig -from .model_loader import get_model - -logger = init_logger(__name__) - - -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - -class ModelRunner(ModelRunner): - def __init__( - self, - model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - super().__init__( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config, - kv_cache_dtype, - is_driver_worker=True, # a hack - prompt_adapter_config=prompt_adapter_config, - return_hidden_states=return_hidden_states, - observability_config=observability_config, - input_registry=input_registry, - mm_registry=mm_registry, - ) - - # NOTE(sgm): add for verl - self.model = model # this will be replaced by get_model() - - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: - self.model = get_model( - self.model, - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) - - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) - - if self.lora_config: - assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet." - - # if supports_multimodal(self.model): - # logger.warning( - # "Regarding multimodal models, vLLM currently only supports adding LoRA to language model." - # ) - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. - if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = self.model.config.max_position_embeddings - else: - max_pos_embeddings = self.model.config.text_config.max_position_embeddings - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - if self.prompt_adapter_config: - self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.device, - self.prompt_adapter_config, - ) - self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) - - if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is deprecated and will be removed. Please include kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2, - ) - self.model.load_kv_cache_scales(self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but model %s does not support loading scaling factors.", - self.model.__class__, - ) - else: - logger.warning("Using FP8 KV cache but no scaling factors provided. Defaulting to scaling factors of 1.0. This may lead to less accurate results!") - - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): - from vllm.plugins import get_torch_compile_backend - - backend = get_torch_compile_backend() or "eager" - self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py deleted file mode 100644 index e37121cf8..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""Model and data parallel groups.""" - -import os -from typing import Optional - -import torch -import torch.distributed -import vllm.distributed.parallel_state as ps -from vllm.distributed.parallel_state import ( - get_pp_group, - get_world_group, - init_distributed_environment, - init_model_parallel_group, -) -from vllm.logger import init_logger - -logger = init_logger(__name__) -""" -This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. -- We assume the Megatron tp+dp+pp world is already established before calling this function. - -""" - -# Device mesh for using DTensor -_DEVICE_MESH = None - -# Tensor model parallel group that the current rank belongs to. -_TP = None -# Pipeline model parallel group that the current rank belongs to. -_PP = None - - -# This method is for initializing the ParallelGroup when using HybridEngine -def initialize_parallel_state( - distributed_init_method: str = "env://", - backend: str = "nccl", - tensor_model_parallel_size: int = 1, - num_tp_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -): - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. - rank = int(os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - # Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) - if torch.distributed.get_world_size() > 1: - # NOTE: build a sepearate inference group with infer tp & micro dp - initialize_model_parallel_for_vllm( - tensor_model_parallel_size=tensor_model_parallel_size, - num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, - ) - else: - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - - -def ensure_model_parallel_initialized( - tensor_model_parallel_size: int, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """Helper to initialize model parallel groups if they are not initialized, - or ensure tensor-parallel and pipeline-parallel sizes are equal to expected - values if the model parallel groups are initialized. - """ - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - return - - assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" - pp_world_size = get_pp_group().world_size - assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" - - -# TODO(sgm): deviate from the v0.5.4, not pp now -def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" - return ps._TP is not None - # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) - - -def initialize_model_parallel_for_vllm( - tensor_model_parallel_size: int, - num_tensor_model_parallel_groups_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -) -> None: - pass - - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - - assert isinstance(tensor_model_parallel_size, int) - - # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group - # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group - - # Build the tensor model-parallel groups. - assert ps._TP is None, "tensor model parallel group is already initialized" - - global _TP - - world_size: int = torch.distributed.get_world_size() - - backend = torch.distributed.get_backend() - - num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - - if num_tensor_model_parallel_groups_per_train_tp == 1: - # if tensor_model_parallel_size == train_tensor_parallel_size: - # using the same tp group as Megatron/vllm - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine - else: - # initialize a micro_dp group and a tp group - # assume training tp=4, infer tp=2, then, weight is partitioned as - # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference - - # Build the inference tp groups - # train_tp = train_tensor_parallel_size - train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size - # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): - start = train_tp * i - end = train_tp * (i + 1) - for j in range(num_tensor_model_parallel_groups_per_train_tp): - ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) - for i in range(len(ranks)): - ranks[i] += j - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # Build the pipeline model-parallel groups. - # global _PIPELINE_MODEL_PARALLEL_GROUP - # global _PIPELINE_GLOBAL_RANKS - # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") - - # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() - # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP # for verl - - -def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """ - NOTE: This method is a hack from the open-sourced version without - asertion of world_size = tp * pp - - Initialize model parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model - parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model - parallelism. - - Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: - 4 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7] - 2 pipeline model-parallel groups: - [g0, g2, g4, g6], [g1, g3, g5, g7] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) - - # NOTE(sgm) we don't assert world_size == tp * pp - # DP is not managed by vllm but by the VeRL WorkerGroup - # if (world_size != - # tensor_model_parallel_size * pipeline_model_parallel_size): - # raise RuntimeError( - # f"world_size ({world_size}) is not equal to " - # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - global _TP - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) - - # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP # for verl - - -""" -Device mesh utilities -""" - - -def get_device_mesh(): - assert _DEVICE_MESH is not None, "device mesh is not initialized" - return _DEVICE_MESH - - -""" -Tensor model parallel utilities -""" - - -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - assert _TP is not None, "tensor model parallel group is not initialized" - return _TP.device_group - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py deleted file mode 100644 index 4e1edf0ba..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py - -import os -import socket -from typing import Iterable, List, Optional, Set, Tuple - -import torch -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - ObservabilityConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest - -from .config import LoadConfig, ModelConfig - -logger = init_logger(__name__) - - -class SPMDGPUExecutor(ExecutorBase): - """SPMD-based multi-GPU executor implementations.""" - - def __init__( - self, - model, # pytorch model itself or its parameter dict - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - - distributed_init_method = initialize_cluster(parallel_config) - self._init_executor(model, distributed_init_method) - - # TODO(sgm): verl not support speculative decode now - def _init_executor(self, model, distributed_init_method) -> None: - assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." - - # Create the parallel worker for each GPU. - self._init_workers_sp(model, distributed_init_method) - - def _init_workers_sp(self, model, distributed_init_method: str): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from .worker import Worker - - rank = int(os.getenv("RANK")) - local_rank = int(os.getenv("LOCAL_RANK")) - print(f"local rank {local_rank}") - - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ["NCCL_CUMEM_ENABLE"] = "0" - - self.worker = Worker( - model, - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.cache_config, - self.load_config, - local_rank, - rank, - distributed_init_method, - lora_config=self.lora_config, - speculative_config=None, - prompt_adapter_config=self.speculative_config, - is_driver_worker=True, - model_runner_cls=None, # use the default one - ) - - # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() - self.worker.init_device() - self.worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self.worker.determine_num_available_blocks() - - # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will - # have its own scheduler - num_gpu_blocks = num_blocks[0] - num_cpu_blocks = num_blocks[1] - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers.""" - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - if torch.distributed.get_rank() == 0: - print(f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") - self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - if torch.distributed.get_rank() == 0: - print(f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") - - # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache - def init_cache_engine(self) -> None: - self.worker._init_cache_engine() - - def free_cache_engine(self) -> None: - self.worker.free_cache_engine() - - def execute_model(self, execute_model_req) -> List[SamplerOutput]: - all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) - - # NOTE(sgm): - # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs - # In vllm with ray, only the driver worker returns the sampling results. - return all_outputs - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self.worker.add_lora(lora_request=lora_request) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.worker.remove_lora(lora_id=lora_id) - - def list_loras(self) -> Set[int]: - return self.worker.list_loras() - - def check_health(self) -> None: - # SPMDExecutor will always be healthy as long as - # it's running. - return - - # NOTE(sgm) add for verl to pass the abstract class test, not used - from vllm.prompt_adapter.request import PromptAdapterRequest - - def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.add_prompt_adapter(prompt_adapter_request) - - def list_prompt_adapters(self) -> Set[int]: - return self.worker.list_prompt_adapters() - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.worker.pin_lora(lora_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.pin_prompt_adapter(prompt_adapter_id) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." - return self.worker.remove_prompt_adapter(prompt_adapter_id) - - # NOTE(sgm): add for verl - def offload_model_weights(self) -> None: - self.worker.offload_model_weights() - - def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: - self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) - - -def initialize_cluster( - parallel_config: ParallelConfig, - engine_use_ray: bool = False, - ray_address: Optional[str] = None, -) -> Tuple[str, Optional[None]]: - """Initialize the distributed cluster probably with Ray. - - Args: - parallel_config: The configurations for parallel execution. - - Returns: - The `distributed_init_method` is the address for initializing the - distributed backend. - """ - - # Initialize cluster locally. - # We need to setup the distributed init method to make sure - # the distributed megatron code (e.g., get world size) works correctly. - # distributed_init_method = f"tcp://localhost:{port}" - distributed_init_method = "env://" - return distributed_init_method - - -def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -# TODO(sgm): not implemented async executor yet -class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): - async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Executes one model step on the given sequences.""" - raise NotImplementedError - - async def check_health_async(self) -> None: - """Checks if the executor is healthy. If not, it should raise an - exception.""" - self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py deleted file mode 100644 index ac94f5447..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py - -from typing import Optional - -from transformers import PreTrainedTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from vllm.utils import LRUCache - - -class TokenizerGroup(TokenizerGroup): - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int]): - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = tokenizer - self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None - - # FIXME(sgm): for simplicity, we assign the special token here - @property - def pad_token_id(self): - return self.tokenizer.pad_token_id - - @property - def eos_token_id(self): - return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_6_3/worker.py b/verl/third_party/vllm/vllm_v_0_6_3/worker.py deleted file mode 100644 index 988be1300..000000000 --- a/verl/third_party/vllm/vllm_v_0_6_3/worker.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py -"""A GPU worker class.""" - -import gc -import os -from typing import Dict, Iterable, List, Optional, Tuple, Type, Union - -import torch -import torch.distributed -import torch.nn as nn -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoRAConfig, - ParallelConfig, - PromptAdapterConfig, - SchedulerConfig, - SpeculativeConfig, -) - -# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state -from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce -from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase -from vllm.worker.model_runner_base import ModelRunnerInputBase -from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype -from vllm.worker.worker_base import WorkerInput - -from .config import LoadConfig, LoadFormat, ModelConfig -from .dtensor_weight_loaders import load_dtensor_weights -from .hf_weight_loader import load_hf_weights -from .megatron_weight_loaders import load_megatron_weights -from .model_runner import ModelRunner -from .parallel_state import ensure_model_parallel_initialized - - -class Worker(Worker): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. - """ - - def __init__( - self, - model: Union[nn.Module, Dict], # model itself or its parameter dict - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - ) -> None: - # self.model = model # will be replaced in the init_model - self.model_config = model_config - self.parallel_config = parallel_config - self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.is_driver_worker = is_driver_worker # TODO: we don't need driver - # if parallel_config and is_driver_worker: - # assert rank % parallel_config.tensor_parallel_size == 0, \ - # "Driver worker should be rank 0 of tensor parallel group." - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - - init_cached_hf_modules() - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_args = {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {"return_hidden_states": True} - - # TODO(sgm): set correct model runner class - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif self.model_config.embedding_mode: - ModelRunnerClass = EmbeddingModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model, # [VERL]: add for verl - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - **speculative_args, - ) - - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] = None - # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - - # NOTE(sgm): [VERL] For offloading inference engine params - self.cpu_model = None - - def init_device(self) -> None: - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. - self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - self.device = torch.device(f"cuda:{local_rank}") - if self.rank < 0: - raise ValueError("Invalid or unspecified rank.") - torch.cuda.set_device(self.device) - - # Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - self.parallel_config.world_size = world_size - - _check_if_gpu_supports_dtype(self.model_config.dtype) - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError(f"Not support device type: {self.device_config.device}") - - # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) - # Set random seed. - set_random_seed(self.model_config.seed) - # self.model = get_model(actor_model=self.model, model_config=self.model_config) - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - # torch.cuda.reset_peak_memory_stats() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - peak_memory = total_gpu_memory - free_gpu_memory - - assert peak_memory > 0, "Error in memory profiling. This happens when the GPU memory was not properly cleaned up before initializing the vLLM instance." - - cache_block_size = self.get_cache_block_size_bytes() - - # NOTE(sgm) [VERL] use the remaining memory - num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) - # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) - - num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - - # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank - num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") - num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") - - torch.distributed.all_reduce(num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) - torch.distributed.all_reduce(num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) - num_gpu_blocks = num_gpu_blocks.item() - num_cpu_blocks = num_cpu_blocks.item() - gc.collect() - torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks - - def _init_cache_engine(self): - if self.cache_engine is None and self.gpu_cache is None: - super()._init_cache_engine() - - def free_cache_engine(self): - # ensure `enforce_eager=True` - self.cache_engine = None - self.gpu_cache = None - - # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() - def execute_model(self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" - worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(execute_model_req.seq_group_metadata_list) - - # verl.worker.workerbase.WorkerBase - # swap cache - super().execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - return self.model_runner.execute_model( - model_input, - self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, - intermediate_tensors, - ) - - # assume the input is .state_dict() - def sync_model_weights(self, actor_weights: Iterable, load_format: str): - if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: - load_megatron_weights(actor_weights, self.model_runner.model) - elif load_format == LoadFormat.HF: - # full model state iterable without no sharding - load_hf_weights(actor_weights, self.model_runner.model) - elif load_format == LoadFormat.DTENSOR: - load_dtensor_weights(actor_weights, self.model_runner.model) - - def offload_model_weights(self) -> None: - if self.cpu_model is None: - self.cpu_model = {} - for name, params in self.model_runner.model.named_parameters(): - self.cpu_model[name] = torch.empty_like(params, device="cpu") - params.data = self.cpu_model[name] - else: - for name, params in self.model_runner.model.named_parameters(): - params.data = self.cpu_model[name] - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = "env://", - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron - init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - - ensure_model_parallel_initialized( - tensor_model_parallel_size=parallel_config.tensor_parallel_size, - pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, - ) - - # TODO(sgm): check whether need this - # if pynccl_utils.is_initialized(): - # pynccl_world_size = pynccl_utils.get_world_size() - # if pynccl_world_size != parallel_config.world_size: - # raise RuntimeError( - # "pynccl is already initialized but the pynccl world " - # "size does not match parallel_config.world_size " - # f"({pynccl_world_size} vs. {parallel_config.world_size}).") - # elif parallel_config.world_size > 1: - # # NOTE(woosuk): We don't initialize pynccl process group when world size - # # is 1. - # # NOTE(kaichao): By default, pynccl is initialized for tp group. - # pynccl_utils.init_process_group( - # group=get_tensor_model_parallel_cpu_group()) - - # # Initialize a custom fast all-reduce implementation. - # if not parallel_config.disable_custom_all_reduce: - # init_custom_ar() - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - # if pynccl_utils.is_initialized(): - # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py index e627b99f2..9a1189d20 100644 --- a/verl/tools/base_tool.py +++ b/verl/tools/base_tool.py @@ -12,9 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple +import json +from typing import Any, Optional from uuid import uuid4 +from verl.utils.rollout_trace import rollout_trace_op + from .schemas import OpenAIFunctionToolSchema @@ -32,8 +35,10 @@ class BaseTool: def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.config = config - self.name = tool_schema.function.name - self.tool_schema = tool_schema + self.tool_schema = tool_schema or self.get_openai_tool_schema() + assert self.tool_schema is not None, "Tool schema is not set!" + self.name = self.tool_schema.function.name + print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema @@ -52,7 +57,8 @@ async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: else: return instance_id - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: """Execute the tool. Args: diff --git a/verl/tools/geo3k_tool.py b/verl/tools/geo3k_tool.py new file mode 100644 index 000000000..6ffd6fb2c --- /dev/null +++ b/verl/tools/geo3k_tool.py @@ -0,0 +1,99 @@ +# Copyright 2023-2025 SGLang Team +# Copyright Amazon.com, Inc. or its affiliates. +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import geo3k +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Geo3kTool(BaseTool): + """A demo tool for calculating the reward of geo3k. + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_geo3k_reward", + "description": "A tool for calculating the reward of geo3k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question, enclosed in \\boxed{}", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id, None + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + self._instance_dict[instance_id]["response"] = answer + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + return f"Current parsed {answer=} {reward=}", tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return geo3k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + use_boxed=False, + format_score=0.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/verl/tools/gsm8k_tool.py b/verl/tools/gsm8k_tool.py index 00118e96e..f6d89134d 100644 --- a/verl/tools/gsm8k_tool.py +++ b/verl/tools/gsm8k_tool.py @@ -15,10 +15,11 @@ import logging import os -from typing import Any, Optional, Tuple +from typing import Any, Optional from uuid import uuid4 from verl.utils.reward_score import gsm8k +from verl.utils.rollout_trace import rollout_trace_op from .base_tool import BaseTool from .schemas import OpenAIFunctionToolSchema @@ -73,7 +74,8 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional } return instance_id - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: answer = parameters.get("answer", "") if not isinstance(answer, str): answer = str(answer) diff --git a/verl/tools/mcp_base_tool.py b/verl/tools/mcp_base_tool.py new file mode 100644 index 000000000..dacd18ebe --- /dev/null +++ b/verl/tools/mcp_base_tool.py @@ -0,0 +1,116 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from fastmcp.exceptions import ClientError + +from verl.tools.utils.mcp_clients.McpClientManager import ClientManager +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPBaseTool(BaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self._instance_dict = {} + self.timeout = config.get("timeout", 30) + + # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool + logger.info(f"Initialized MCPBaseTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id + + async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: + err_msg = "" + try: + call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout) + except ClientError as e: + err_msg = f"\n Tool call failed: {e}" + except ConnectionError as e: + err_msg = f"\n Connection failed: {e}" + except Exception as e: + err_msg = f"\n An unexpected error occurred: {e}" + + logger.debug(f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}") + result, metadata = self._parse_tool_result(call_tool_result.content) + metadata["api_request_error"] += err_msg + return result, metadata + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + if self.name == "" or self.name is None or parameters is None: + error_msg = "Error: 'parameters' is missing or empty." + logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}") + return json.dumps({"result": error_msg}), 0.0, {} + + try: + result_text, metadata = await self._call_tool(instance_id, parameters) + + # Store results in instance dictionary + self._instance_dict[instance_id]["reward"].append(result_text.strip()) + + # Convert metadata to metrics + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } + + return result_text, 0.0, metrics + + except Exception as e: + error_result = json.dumps({"result": f"Tool execution failed: {e}"}) + logger.error(f"[MCPBaseTool] Execution failed: {e}") + return error_result, 0.0, {"error": str(e)} + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + tools_content = [part.text for part in filter(lambda x: x.type == "text", content)] + return " ".join(tools_content), {} diff --git a/verl/tools/mcp_search_tool.py b/verl/tools/mcp_search_tool.py new file mode 100644 index 000000000..ac823719b --- /dev/null +++ b/verl/tools/mcp_search_tool.py @@ -0,0 +1,69 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import re + +from verl.tools.mcp_base_tool import MCPBaseTool + +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPSearchTool(MCPBaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + res = "" + res_cnt = 0 + query_list = [] + metadata = { + "api_request_error": "", + "status": "unknown", + "total_results": 0, + } + try: + for part in content: + if part.type != "text": + continue + text = part.text.replace("'", '"') + query_match = re.search(r'query"\s*:\s*"([^"]+)"', text) + query = query_match.group(1) if query_match else "" + query_list.append(query) + + title_matches = re.findall(r'"title"\s*:', text) + title_count = len(title_matches) + + results_match = re.search(r'"results"\s*:\s*(\[.*?\])', text, re.DOTALL) + results_content = results_match.group(1) if results_match else "" + + res += results_content + res_cnt += title_count + except json.JSONDecodeError: + err_msg = "json parse error." + logger.error(err_msg) + metadata["api_request_error"] = err_msg + metadata["status"] = "error" + + # update metadata + metadata["status"] = "success" + metadata["queries"] = query_list + metadata["query_count"] = len(query_list) + metadata["total_results"] = res_cnt + return res, metadata diff --git a/verl/tools/sandbox_fusion_tools.py b/verl/tools/sandbox_fusion_tools.py index ee274bcd2..c3a2748d9 100644 --- a/verl/tools/sandbox_fusion_tools.py +++ b/verl/tools/sandbox_fusion_tools.py @@ -17,15 +17,14 @@ import threading from contextlib import ExitStack from enum import Enum -from typing import Any, Callable, Optional, Tuple, TypeVar +from typing import Any, Callable, Optional, TypeVar from uuid import uuid4 import ray -import ray.actor -import ray.util.multiprocessing from verl.tools.base_tool import BaseTool from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case +from verl.utils.rollout_trace import rollout_trace_op from .schemas import OpenAIFunctionToolSchema @@ -85,9 +84,15 @@ def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: logger.warning(f"Error when executing code: {e}") -def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode): +def init_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): if mode == PoolMode.ThreadMode: - return ray.remote(ExecutionWorker).options(max_concurrency=num_workers).remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + return ( + ray.remote(ExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) else: raise NotImplementedError("Process mode is not implemented yet") # return ray.util.multiprocessing.Pool(processes=num_workers) @@ -131,8 +136,14 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.default_timeout = config.get("default_timeout", 30) self.default_language = config.get("default_language", "python") self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) - self.execution_pool = init_execution_pool(num_workers=self.num_workers, enable_global_rate_limit=self.enable_global_rate_limit, rate_limit=self.rate_limit, mode=PoolMode.ThreadMode) + self.execution_pool = init_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) self.sandbox_fusion_url = config.get("sandbox_fusion_url", "") + self.memory_limit_mb = config.get("memory_limit_mb", 1024) if self.sandbox_fusion_url == "": raise ValueError("sandbox_fusion_url is not set") log_msg = f"Init SandboxFusionTool with config: {config}" @@ -151,7 +162,8 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional } return instance_id - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: code = parameters.get("code", "") timeout = parameters.get("timeout", self.default_timeout) language = parameters.get("language", self.default_language) @@ -159,14 +171,16 @@ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) code = str(code) result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) - - return result, result, result.strip() + # sandbox has no score or metrics, use Nones + return result, None, None def execute_code(self, instance_id, code, timeout=30, language="python"): - result_status, metadata = _process_single_case(0, None, None, self.sandbox_fusion_url, code, timeout, language) + result_status, metadata = _process_single_case( + 0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language + ) # we should always expect this since we don't have correct answer if metadata["run_status"] == "Finished": - actual_output = metadata["stdout"] if metadata["stdout"] is not None else "" + actual_output = metadata["stdout"] + metadata["stderr"] logger.debug(f"actual_output from sandbox fusion: {actual_output},{instance_id}") return actual_output else: diff --git a/verl/tools/schemas.py b/verl/tools/schemas.py index 187ab4aa1..c0c65a30e 100644 --- a/verl/tools/schemas.py +++ b/verl/tools/schemas.py @@ -64,7 +64,9 @@ class OpenAIFunctionCallSchema(BaseModel): arguments: dict[str, Any] @staticmethod - def from_openai_function_parsed_schema(parsed_schema: OpenAIFunctionParsedSchema) -> tuple["OpenAIFunctionCallSchema", bool]: + def from_openai_function_parsed_schema( + parsed_schema: OpenAIFunctionParsedSchema, + ) -> tuple["OpenAIFunctionCallSchema", bool]: has_decode_error = False try: arguments = json.loads(parsed_schema.arguments) diff --git a/verl/tools/search_tool.py b/verl/tools/search_tool.py index b66200a43..3cc6cda53 100644 --- a/verl/tools/search_tool.py +++ b/verl/tools/search_tool.py @@ -19,13 +19,14 @@ import threading from contextlib import ExitStack from enum import Enum -from typing import Any, Callable, Optional, Tuple, TypeVar +from typing import Any, Callable, Optional, TypeVar from uuid import uuid4 import ray import ray.actor from verl.tools.utils.search_r1_like_utils import perform_single_search_batch +from verl.utils.rollout_trace import rollout_trace_op from .base_tool import BaseTool from .schemas import OpenAIFunctionToolSchema @@ -99,10 +100,16 @@ def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: return fn(*fn_args, **fn_kwargs) -def init_search_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode): +def init_search_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): """Initialize search execution pool.""" if mode == PoolMode.ThreadMode: - return ray.remote(SearchExecutionWorker).options(max_concurrency=num_workers).remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + return ( + ray.remote(SearchExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) else: raise NotImplementedError("Process mode is not implemented yet") @@ -158,7 +165,12 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.timeout = config.get("timeout", 30) self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) - self.execution_pool = init_search_execution_pool(num_workers=self.num_workers, enable_global_rate_limit=self.enable_global_rate_limit, rate_limit=self.rate_limit, mode=PoolMode.ThreadMode) + self.execution_pool = init_search_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) # Retrieval service configuration self.retrieval_service_url = config.get("retrieval_service_url") @@ -213,7 +225,8 @@ def execute_search(self, instance_id: str, query_list: list, retrieval_service_u logger.debug(f"Search result for instance {instance_id}: {result_text}") return result_text, metadata - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: """Execute the search tool. Args: @@ -235,13 +248,20 @@ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) # Execute search using Ray execution pool try: - result_text, metadata = await self.execution_pool.execute.remote(self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout) + result_text, metadata = await self.execution_pool.execute.remote( + self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout + ) # Store results in instance dictionary self._instance_dict[instance_id]["reward"].append(result_text.strip()) # Convert metadata to metrics - metrics = {"query_count": metadata.get("query_count", 0), "status": metadata.get("status", "unknown"), "total_results": metadata.get("total_results", 0), "api_request_error": metadata.get("api_request_error")} + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } return result_text, 0.0, metrics diff --git a/verl/tools/utils/mcp_clients/McpClientManager.py b/verl/tools/utils/mcp_clients/McpClientManager.py new file mode 100644 index 000000000..ee5fe3119 --- /dev/null +++ b/verl/tools/utils/mcp_clients/McpClientManager.py @@ -0,0 +1,97 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import logging +from typing import Any + +from fastmcp import Client +from fastmcp.client.transports import SSETransport + +from verl.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai + +logger = logging.getLogger(__name__) + + +class MCPClientManager: + rootServerName = "mcpServers" + initialized = False + clients = [] + tool_client_mapping = {} + rate_limiter = None + + async def initialize(self, config_path, rate_limit: float = 10.0): + if self.initialized: + return + """Initialize the MCP Client Manager and start all clients""" + result = self._load_config(config_path) + servers = result[self.rootServerName] + exclude_sse_servers = {self.rootServerName: {}} + for server_name in servers.keys(): + server = servers[server_name] + if "auth_token" in server: + transport = SSETransport(url=server["url"], headers={"Authorization": f"Bearer {server['auth_token']}"}) + client = Client(transport) + self.clients.append(client) + else: + exclude_sse_servers[self.rootServerName][server_name] = server + + if exclude_sse_servers[self.rootServerName]: + self.clients.append(Client(exclude_sse_servers)) + + # Initialize rate limiter + self.rate_limiter = TokenBucket(rate_limit) + self.initialized = True + + async def call_tool(self, tool_name, parameters, timeout): + # Apply rate limiting + while not self.rate_limiter.acquire(): + await asyncio.sleep(0.1) + + client = self.get_client_with_tool_name(tool_name) + async with client: + return await client.call_tool_mcp(tool_name, parameters) + + async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]: + tool_schemas = [] + for client in self.clients: + async with client: + tools = await client.list_tools_mcp() + for tool in tools.tools: + if not tool_selected_list: + self.tool_client_mapping[tool.name] = client + tool_schemas.append(mcp2openai(tool)) + elif tool.name in tool_selected_list: + self.tool_client_mapping[tool.name] = client + tool_schemas.append(mcp2openai(tool)) + + return tool_schemas + + def get_client_with_tool_name(self, tool_name: str): + return self.tool_client_mapping[tool_name] + + def _load_config(self, file: str) -> dict[str, Any]: + try: + with open(file) as f: + return json.load(f) + except FileNotFoundError: + logger.warning(f'the "{file}" file was not found') + except Exception: + logger.error(f'there was an error reading the "{file}" file') + + return {} + + +ClientManager = MCPClientManager() diff --git a/verl/tools/utils/mcp_clients/utils.py b/verl/tools/utils/mcp_clients/utils.py new file mode 100644 index 000000000..22a5f6353 --- /dev/null +++ b/verl/tools/utils/mcp_clients/utils.py @@ -0,0 +1,58 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time + +from mcp import Tool + +logger = logging.getLogger(__file__) + + +class TokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit # tokens per second + self.tokens = rate_limit + self.last_update = time.time() + self.lock = threading.Lock() + + def acquire(self) -> bool: + with self.lock: + now = time.time() + # Add new tokens based on time elapsed + new_tokens = (now - self.last_update) * self.rate_limit + self.tokens = min(self.rate_limit, self.tokens + new_tokens) + self.last_update = now + + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + +def mcp2openai(mcp_tool: Tool) -> dict: + """Convert a MCP Tool to an OpenAI ChatCompletionTool.""" + openai_format = { + "type": "function", + "function": { + "name": mcp_tool.name, + "description": mcp_tool.description, + "parameters": mcp_tool.inputSchema, + "strict": False, + }, + } + if not openai_format["function"]["parameters"].get("required", None): + openai_format["function"]["parameters"]["required"] = [] + return openai_format diff --git a/verl/tools/utils/search_r1_like_utils.py b/verl/tools/utils/search_r1_like_utils.py index 8a3bb1bba..23669e44c 100644 --- a/verl/tools/utils/search_r1_like_utils.py +++ b/verl/tools/utils/search_r1_like_utils.py @@ -19,7 +19,7 @@ import time import traceback import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import requests @@ -31,7 +31,13 @@ logger = logging.getLogger(__name__) -def call_search_api(retrieval_service_url: str, query_list: List[str], topk: int = 3, return_scores: bool = True, timeout: int = DEFAULT_TIMEOUT) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: +def call_search_api( + retrieval_service_url: str, + query_list: list[str], + topk: int = 3, + return_scores: bool = True, + timeout: int = DEFAULT_TIMEOUT, +) -> tuple[Optional[dict[str, Any]], Optional[str]]: """ Calls the remote search API to perform retrieval with retry logic for various errors, using increasing delay between retries. Logs internal calls with a unique ID. @@ -59,7 +65,9 @@ def call_search_api(retrieval_service_url: str, query_list: List[str], topk: int for attempt in range(MAX_RETRIES): try: - logger.info(f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}") + logger.info( + f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}" + ) response = requests.post( retrieval_service_url, headers=headers, @@ -69,7 +77,10 @@ def call_search_api(retrieval_service_url: str, query_list: List[str], topk: int # Check for Gateway Timeout (504) and other server errors for retrying if response.status_code in [500, 502, 503, 504]: - last_error = f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt {attempt + 1}/{MAX_RETRIES}" + last_error = ( + f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt " + f"{attempt + 1}/{MAX_RETRIES}" + ) logger.warning(last_error) if attempt < MAX_RETRIES - 1: delay = INITIAL_RETRY_DELAY * (attempt + 1) @@ -127,7 +138,13 @@ def _passages2string(retrieval_result): return format_reference.strip() -def perform_single_search_batch(retrieval_service_url: str, query_list: List[str], topk: int = 3, concurrent_semaphore: Optional[threading.Semaphore] = None, timeout: int = DEFAULT_TIMEOUT) -> Tuple[str, Dict[str, Any]]: +def perform_single_search_batch( + retrieval_service_url: str, + query_list: list[str], + topk: int = 3, + concurrent_semaphore: Optional[threading.Semaphore] = None, + timeout: int = DEFAULT_TIMEOUT, +) -> tuple[str, dict[str, Any]]: """ Performs a single batch search for multiple queries (original search tool behavior). @@ -151,9 +168,21 @@ def perform_single_search_batch(retrieval_service_url: str, query_list: List[str try: if concurrent_semaphore: with concurrent_semaphore: - api_response, error_msg = call_search_api(retrieval_service_url=retrieval_service_url, query_list=query_list, topk=topk, return_scores=True, timeout=timeout) + api_response, error_msg = call_search_api( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + return_scores=True, + timeout=timeout, + ) else: - api_response, error_msg = call_search_api(retrieval_service_url=retrieval_service_url, query_list=query_list, topk=topk, return_scores=True, timeout=timeout) + api_response, error_msg = call_search_api( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + return_scores=True, + timeout=timeout, + ) except Exception as e: error_msg = f"API Request Exception during batch search: {e}" logger.error(f"Batch search: {error_msg}") diff --git a/verl/tools/utils/tool_registry.py b/verl/tools/utils/tool_registry.py new file mode 100644 index 000000000..5c14d1016 --- /dev/null +++ b/verl/tools/utils/tool_registry.py @@ -0,0 +1,107 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import importlib +import logging +import os +import sys +from enum import Enum + +from omegaconf import OmegaConf + +from verl.tools.schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class ToolType(Enum): + NATIVE = "native" + MCP = "mcp" + + +async def initialize_mcp_tool(tool_cls, tool_config) -> list: + from verl.tools.utils.mcp_clients.McpClientManager import ClientManager + + tool_list = [] + mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path + tool_selected_list = tool_config.mcp.tool_selected_list if "tool_selected_list" in tool_config.mcp else None + await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit) + # Wait for MCP client to be ready + max_retries = 10 + retry_interval = 2 # seconds + for i in range(max_retries): + tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list) + if tool_schemas: + break + if i < max_retries - 1: + logger.debug(f"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}") + await asyncio.sleep(retry_interval) + else: + raise RuntimeError("Failed to initialize MCP tools after maximum retries") + # mcp registry + assert len(tool_schemas), "mcp tool is empty" + for tool_schema_dict in tool_schemas: + logger.debug(f"tool_schema_dict: {tool_schema_dict}") + tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema, + ) + tool_list.append(tool) + return tool_list + + +def get_tool_class(cls_name): + module_name, class_name = cls_name.rsplit(".", 1) + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + tool_cls = getattr(module, class_name) + return tool_cls + + +def initialize_tools_from_config(tools_config_file): + tools_config = OmegaConf.load(tools_config_file) + tool_list = [] + for tool_config in tools_config.tools: + cls_name = tool_config.class_name + tool_type = ToolType(tool_config.config.type) + tool_cls = get_tool_class(cls_name) + + match tool_type: + case ToolType.NATIVE: + if tool_config.get("tool_schema", None) is None: + tool_schema = None + else: + tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) + tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema, + ) + tool_list.append(tool) + case ToolType.MCP: + loop = asyncio.get_event_loop() + mcp_tools = loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config)) + tool_list.extend(mcp_tools) + case _: + raise NotImplementedError + return tool_list diff --git a/verl/trainer/config/__init__.py b/verl/trainer/config/__init__.py new file mode 100644 index 000000000..f4cc9b8e2 --- /dev/null +++ b/verl/trainer/config/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig + +__all__ = [ + "AlgoConfig", + "FilterGroupsConfig", + "KLControlConfig", + "PFPPOConfig", +] diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml new file mode 100644 index 000000000..1d715e919 --- /dev/null +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -0,0 +1,372 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + policy_loss: + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + entropy_coeff: 0 + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + optim: + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + min_lr_ratio: 0.0 + num_cycles: 0.5 + warmup_style: constant + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + ref: + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + fsdp_config: + param_offload: false + reshard_after_forward: true + forward_prefetch: false + wrap_policy: + min_num_params: 0 + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + name: vllm + mode: sync + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: true + free_cache_engine: true + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + multi_stage_wake_up: false + engine_kwargs: + vllm: + swap_space: null + disable_mm_preprocessor_cache: false + sglang: + attention_backend: null + val_kwargs: + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + completion_callback: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + calculate_log_probs: false + agent: + num_workers: 8 + agent_loop_config_path: null + custom_async_server: + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + backend: null + token2text: false + enable_chunked_prefill: true + load_format: dummy_dtensor + layered_summon: false + hybrid_engine: true + model: + path: ~/models/deepseek-llm-7b-chat + custom_chat_template: null + use_shm: false + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + trust_remote_code: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] +trainer: + npu_profile: + options: + save_path: ./profiler_data + level: level1 + with_memory: false + record_shapes: false + with_npu: true + with_cpu: true + with_module: false + with_stack: false + analysis: true + balance_batch: true + total_epochs: 30 + total_training_steps: null + profile_steps: null + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + return_raw_input_ids: false + return_raw_chat: false + return_full_prompt: false + shuffle: true + dataloader_num_workers: 8 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null +critic: + rollout_n: ${actor_rollout_ref.rollout.n} + strategy: fsdp + optim: + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr: 1.0e-05 + min_lr_ratio: null + warmup_style: constant + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: true + trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} + use_shm: false + enable_activation_offload: false + use_remove_padding: false + fsdp_config: + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + fsdp_size: -1 + forward_prefetch: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + cliprange_value: 0.5 + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: false + all_ranks: false + ranks: [] + ulysses_sequence_parallel_size: 1 +custom_reward_function: + path: null + name: compute_score +algorithm: + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + filter_groups: + enable: false + metric: null + max_num_gen_batches: 0 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + _target_: verl.trainer.config.PFPPOConfig + reweight_method: pow + weight_pow: 2.0 +ray_init: + num_cpus: null + timeline_json_file: null diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml new file mode 100644 index 000000000..d5402d870 --- /dev/null +++ b/verl/trainer/config/actor/actor.yaml @@ -0,0 +1,111 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# the abstract actor configs +# fsdp, fsdp2 or megatron. must be set. +strategy: ??? + +# Split each sample into sub-batches of this size for PPO +ppo_mini_batch_size: 256 + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: null + +# Whether to automatically adjust batch size at runtime +# oc.select: the default val for ref.log_prob_use_dynamic_bsz +use_dynamic_bsz: false + +# Max tokens per GPU in one PPO batch; affects gradient accumulation +# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} +# oc.select: the default val for ref.log_prob_max_token_len_per_gpu +ppo_max_token_len_per_gpu: 16384 + +# PPO clip ratio +clip_ratio: 0.2 + +# Lower bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_low: 0.2 + +# Upper bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_high: 0.2 + +# policy loss config +policy_loss: + + # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + loss_mode: "vanilla" + + # Ratio of tokens to be clipped for clip-cov loss + clip_cov_ratio: 0.0002 + + # Lower bound for clip-cov loss + clip_cov_lb: 1.0 + + # Upper bound for clip-cov loss + clip_cov_ub: 5.0 + + # Ratio of tokens to be applied kl penalty for kl-cov loss + kl_cov_ratio: 0.0002 + + # KL divergence penalty coefficient + ppo_kl_coef: 0.1 + +# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C +clip_ratio_c: 3.0 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: token-mean + +# Entropy regularization coefficient in PPO loss +entropy_coeff: 0 + +# Whether to use KL loss instead of KL reward penalty. True for GRPO +use_kl_loss: false + +# Whether to use torch.compile() +# oc.select: the default val for ref.use_torch_compile +use_torch_compile: true + +# KL loss coefficient when use_kl_loss is enabled. For GRPO +kl_loss_coef: 0.001 + +# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" +kl_loss_type: low_var_kl + +# Number of PPO epochs per batch +ppo_epochs: 1 + +# Shuffle training data across PPO epochs +shuffle: false + +# checkpoint configs +checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg + load_contents: ${.save_contents} + +# optimizer configs +optim: + + # Learning rate + lr: 1e-6 + + # Warmup steps ratio (used if lr_warmup_steps is negative) + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml new file mode 100644 index 000000000..f298c3cfa --- /dev/null +++ b/verl/trainer/config/actor/dp_actor.yaml @@ -0,0 +1,73 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/actor/actor.yaml + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# TODO(haibin.lin): switch to fsdp2 +strategy: fsdp + +# Gradient clipping for actor updates, specific to the strategy. +grad_clip: 1.0 + +# Sequence parallelism size for Ulysses-style model parallelism +# oc.select: the default val for ref.ulysses_sequence_parallel_size +ulysses_sequence_parallel_size: 1 + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False + +# optimizer configs +optim: + + # Warmup steps; negative value delegates to lr_warmup_steps_ratio + lr_warmup_steps: -1 + + # Minimum LR ratio for cosine schedule + min_lr_ratio: 0.0 + + # Number of cosine cycles in LR schedule + num_cycles: 0.5 + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + +# configs for FSDP +fsdp_config: + + # policy for wrapping the model + wrap_policy: + + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + + # Whether to offload model parameters to CPU (trades speed for memory) + param_offload: false + + # Whether to offload optimizer state to CPU + optimizer_offload: false + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: false + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: true + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False diff --git a/verl/trainer/config/actor/megatron_actor.yaml b/verl/trainer/config/actor/megatron_actor.yaml new file mode 100644 index 000000000..6492eab2f --- /dev/null +++ b/verl/trainer/config/actor/megatron_actor.yaml @@ -0,0 +1,104 @@ +# megatron actor config, inheriting from trainer/config/actor/actor.yaml +defaults: + - actor + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +data_loader_seed: null + +load_weight: True + +checkpoint: + + async_save: False + +optim: + + optimizer: adam + + clip_grad: 1.0 + + # initial learning rate for warmup, default to 0.0 + lr_warmup_init: 0.0 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: null + + lr_decay_steps: null + + # select from constant/linear/cosine/inverse_square_root + lr_decay_style: constant + + # minimum learning rate, default to 0.0 + min_lr: 0.0 + + # select from constant/linear/cosine + weight_decay_incr_style: constant + + # select from constant/exponential/cosine + lr_wsd_decay_style: exponential + + lr_wsd_decay_steps: null + + # use checkpoint optimizer parameter scheduler + use_checkpoint_opt_param_scheduler: False + +megatron: + + param_offload: False + + grad_offload: False + + optimizer_offload: False + + tensor_model_parallel_size: 1 + + expert_model_parallel_size: 1 + + expert_tensor_parallel_size: null + + pipeline_model_parallel_size: 1 + + virtual_pipeline_model_parallel_size: null + + context_parallel_size: 1 + + sequence_parallel: True + + use_distributed_optimizer: True + + use_dist_checkpointing: False + + dist_checkpointing_path: null + + # oc.select: default val for ref.megatron.seed + seed: 42 + + # Allow to override Distributed Data Parallel (DDP) config + override_ddp_config: {} + + # additional transformer config like: num_layers_in_first(/last)_pipeline_stage + # oc.select: default val for ref.megatron.override_transformer_config + override_transformer_config: {} + + # oc.select: default val for ref.megatron.use_mbridge + use_mbridge: False + +# profile the actor model in `update_policy` +profile: + # turn it on when you want to profile the actor model + use_profile: False + + # list, you can specify the ranks to profile + profile_ranks: null + + # start step in update_policy + step_start: -1 + + # end step + step_end: -1 + + # the path to save the profile result + save_path: null diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py new file mode 100644 index 000000000..5bc6cf943 --- /dev/null +++ b/verl/trainer/config/algorithm.py @@ -0,0 +1,114 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from verl.base_config import BaseConfig + + +@dataclass +class KLControlConfig(BaseConfig): + """Configuration for KL control. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + type (str): Type of KL control. Can be "fixed" or "adaptive". + kl_coef (float): Initial coefficient for KL penalty. + horizon (int): Horizon value for adaptive controller. + target_kl (float): Target KL divergence for adaptive controller. + """ + + _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"] + type: str = "fixed" + kl_coef: float = 0.001 + horizon: int = 10000 + target_kl: float = 0.1 + + +@dataclass +class PFPPOConfig(BaseConfig): + """Configuration for preference feedback PPO. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + reweight_method (str): Method for reweighting samples. Can be "pow", "max_min", or "max_random". + weight_pow (float): Power used for weight scaling in "pow" method. + """ + + _frozen_fields = ["reweight_method", "weight_pow"] + reweight_method: str = "pow" + weight_pow: float = 2.0 + + +@dataclass +class FilterGroupsConfig(BaseConfig): + """Configuration for filter groups (used in DAPO and Entropy). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + enable (bool): Whether to enable filter groups. + metric (Optional[str]): Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. + max_num_gen_batches (int): Non-positive values mean no upper limit. + """ + + _frozen_fields = ["enable", "metric", "max_num_gen_batches"] + + enable: bool = False + metric: Optional[str] = None + max_num_gen_batches: int = 0 + + +@dataclass +class AlgoConfig(BaseConfig): + """Configuration for the algorithm. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + gamma (float): Discount factor for future rewards. + lam (float): Trade-off between bias and variance in the GAE estimator. + adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO). + use_kl_in_reward (bool): Whether to enable in-reward KL penalty. + kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full". + kl_ctrl (KLControlConfig): KL control configuration. + use_pf_ppo (bool): Whether to enable preference feedback PPO. + pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings. + filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy + """ + + _frozen_fields = [ + "gamma", + "lam", + "adv_estimator", + "norm_adv_by_std_in_grpo", + "use_kl_in_reward", + "kl_penalty", + "use_pf_ppo", + ] + + gamma: float = 1.0 + lam: float = 1.0 + adv_estimator: str = "gae" + norm_adv_by_std_in_grpo: bool = True + use_kl_in_reward: bool = False + kl_penalty: str = "kl" + kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) + use_pf_ppo: bool = False + pf_ppo: Optional[PFPPOConfig] = None + filter_groups: Optional[FilterGroupsConfig] = None diff --git a/verl/trainer/config/critic/critic.yaml b/verl/trainer/config/critic/critic.yaml new file mode 100644 index 000000000..a02fca231 --- /dev/null +++ b/verl/trainer/config/critic/critic.yaml @@ -0,0 +1,94 @@ +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${actor_rollout_ref.rollout.n} + +# fsdp or fsdp2 strategy used for critic model training +strategy: ??? + +# optimizer configs +optim: + + # Warmup steps ratio; total steps will be injected at runtime + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + +# model config for the critic +model: + + # Path to pretrained model weights + path: ~/models/deepseek-llm-7b-chat + + # Tokenizer path (defaults to actor's model path) + tokenizer_path: ${actor_rollout_ref.model.path} + + # Hugging Face config override + override_config: {} + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: True + + # Whether to trust remote code from Hugging Face models + trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} + +# PPO mini-batch size per update +ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: null + +# Whether to automatically adjust batch size at runtime +use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + +# Max tokens per GPU in one PPO batch (doubled for critic) +ppo_max_token_len_per_gpu: 32768 + +# Max token length per GPU in forward pass +forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + +# Number of PPO epochs per batch +ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + +# Shuffle training data across PPO epochs +shuffle: ${actor_rollout_ref.actor.shuffle} + +# PPO value function clipping range +cliprange_value: 0.5 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} + +# checkpoint configs +checkpoint: + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # What to include when loading checkpoints + load_contents: ${.save_contents} + +# profiler configs +# the corresponding dataclass is verl.utils.profiler.ProfilerConfig. +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] \ No newline at end of file diff --git a/verl/trainer/config/critic/dp_critic.yaml b/verl/trainer/config/critic/dp_critic.yaml new file mode 100644 index 000000000..88efe143a --- /dev/null +++ b/verl/trainer/config/critic/dp_critic.yaml @@ -0,0 +1,89 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: fsdp + +# optimizer configs +optim: + + # Learning rate + lr: 1e-5 + + # Minimum LR ratio for cosine schedule + min_lr_ratio: null + + # LR warmup style: "constant" or "cosine" + warmup_style: constant + +# model config for the critic +model: + + # Whether to use shared memory for loading the model + use_shm: False + + # Offload activations to CPU to reduce GPU memory usage + enable_activation_offload: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # FSDP-specific config + fsdp_config: + + # Whether to offload model parameters to CPU + param_offload: False + + # Whether to offload optimizer state to CPU + optimizer_offload: False + + # Only for FSDP2: offload param/grad/optimizer during train + offload_policy: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Policy for wrapping layers with FSDP + wrap_policy: + + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # LoRA target modules: "all-linear" or list of linear projection layers + target_modules: all-linear + +# Forward-only batch size during inference (global) +forward_micro_batch_size: ${critic.ppo_micro_batch_size} + +# Forward-only batch size during inference (per GPU) +forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + +# Sequence parallelism size for Ulysses-style model parallelism +ulysses_sequence_parallel_size: 1 + +# Gradient clipping for critic updates +grad_clip: 1.0 \ No newline at end of file diff --git a/verl/trainer/config/critic/megatron_critic.yaml b/verl/trainer/config/critic/megatron_critic.yaml new file mode 100644 index 000000000..7db811edc --- /dev/null +++ b/verl/trainer/config/critic/megatron_critic.yaml @@ -0,0 +1,138 @@ +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# optimizer configs +optim: + + # select optimizer, default is Adam + optimizer: adam + + # Learning rate + lr: 1e-6 + + # Clip gradients norm + clip_grad: 1.0 + + # initial learning rate for warmup, default to 0.0 + lr_warmup_init: 0.0 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: null + + lr_decay_steps: null + + # select from constant/linear/cosine/inverse_square_root + lr_decay_style: linear + + # minimum learning rate, default to 0.0 + min_lr: 0.0 + + # select from constant/linear/cosine + weight_decay_incr_style: constant + + # select from constant/exponential/cosine + lr_wsd_decay_style: exponential + + # number of steps for weight std decay + lr_wsd_decay_steps: null + + # use checkpoint optimizer parameter scheduler + use_checkpoint_opt_param_scheduler: False + +# model config for the critic +model: + + # override default empty mapping + override_config: + model_config: {} + moe_config: + freeze_moe_router: False + + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: False + + # Activation Checkpointing settings + gradient_checkpointing_kwargs: + activations_checkpoint_method: null + activations_checkpoint_granularity: null + activations_checkpoint_num_layers: null + +# megatron-specific parallelism settings +megatron: + + # Whether to offload model parameters to CPU + param_offload: False + + # Whether to offload gradients to CPU + grad_offload: False + + # Whether to offload optimizer state to CPU + optimizer_offload: False + + # size of tensor model parallel group + tensor_model_parallel_size: 1 + + # size of expert model parallel group + expert_model_parallel_size: 1 + + # size of expert tensor parallel group + expert_tensor_parallel_size: null + + # size of pipeline model parallel group + pipeline_model_parallel_size: 1 + + # size of virtual pipeline model parallel group + virtual_pipeline_model_parallel_size: null + + # size of context parallel group + context_parallel_size: 1 + + # Whether to use sequence parallelism + sequence_parallel: True + + # Whether to use distributed optimizer + use_distributed_optimizer: True + + # Whether to use distributed checkpointing + use_dist_checkpointing: False + + # Path for distributed checkpointing + dist_checkpointing_path: null + + # Random seed for Megatron + seed: ${actor_rollout_ref.actor.megatron.seed} + + # Allow to override Distributed Data Parallel (DDP) config + override_ddp_config: ${actor_rollout_ref.actor.megatron.override_ddp_config} + + # Transformer config overrides for Megatron + override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} + + # Whether to use mBridge communications + use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} + +# Whether to load initial weights +load_weight: True + +# seed for data loader +data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} + +# KL control settings +kl_ctrl: + type: fixed + kl_coef: 0.001 + +# Asynchronous checkpoint saving +checkpoint: + async_save: False \ No newline at end of file diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml new file mode 100644 index 000000000..9a5ce8f0d --- /dev/null +++ b/verl/trainer/config/data/legacy_data.yaml @@ -0,0 +1,109 @@ +# Tokenizer class or path. If null, it will be inferred from the model. +tokenizer: null + +# Whether to use shared memory for data loading. +use_shm: False + +# Training set parquet. Can be a list or a single file. +# The program will read all files into memory, so it can't be too large (< 100GB). +# The path can be either a local path or an HDFS path. +# For HDFS path, we provide utils to download it to DRAM and convert it to a local path. +train_files: ~/data/rlhf/gsm8k/train.parquet + +# Validation parquet. Can be a list or a single file. +val_files: ~/data/rlhf/gsm8k/test.parquet + +# The field in the dataset where the prompt is located. Default is 'prompt'. +prompt_key: prompt + +# The field used to select the reward function (if using different ones per example). +reward_fn_key: data_source + +# Maximum prompt length. All prompts will be left-padded to this length. +# An error will be reported if the length is too long. +# oc.select: default val for rollout.prompt_length +max_prompt_length: 512 + +# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. +# oc.select: default val for rollout.response_length +max_response_length: 512 + +# Batch size sampled for one training iteration of different RL algorithms. +train_batch_size: 1024 + +# Batch size used during validation. Can be null. +val_batch_size: null + +# Whether to return the original input_ids without adding chat template. +# This is used when the reward model's chat template differs from the policy. +# If using a model-based RM with different templates, this should be True. +return_raw_input_ids: False + +# Whether to return the original chat (prompt) without applying chat template. +return_raw_chat: False + +# Whether to return the full prompt with chat template. +return_full_prompt: False + +# Whether to shuffle the data in the dataloader. +shuffle: True + +# num dataloader workers +dataloader_num_workers: 8 + +# Whether to shuffle the validation set. +validation_shuffle: False + +# Whether to filter overlong prompts. +filter_overlong_prompts: False + +# Number of workers for filtering overlong prompts. +# For large-scale datasets, filtering can be time-consuming. +# Use multiprocessing to speed up. Default is 1. +filter_overlong_prompts_workers: 1 + +# Truncate the input_ids or prompt if they exceed max_prompt_length. +# Options: 'error', 'left', 'right', 'middle'. Default is 'error'. +truncation: error + +# The field in the multi-modal dataset where the image is located. Default is 'images'. +image_key: images + +# The field in the multi-modal dataset where the video is located. +video_key: videos + +# If the remote tokenizer has a Python file, this flag determines whether to allow using it. +trust_remote_code: False + +# Optional: specify a custom dataset class path and name if overriding default loading behavior. +custom_cls: + + # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. + path: null + + # The name of the dataset class within the specified file. + name: null + +# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. +return_multi_modal_inputs: True + +# settings related to data sampler +sampler: + + # the path to the module containing a curriculum class which implements the + # AbstractSampler interface + class_path: null + + # the name of the curriculum class like `MySampler` + class_name: null + +# Data generation configuration for augmenting the dataset. +datagen: + + # The path to the file containing your customized data generation class. + # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' + path: null + + # The class name of the data generation class within the specified file. + # E.g. 'MockDataGenerator' + name: null \ No newline at end of file diff --git a/verl/trainer/config/evaluation.yaml b/verl/trainer/config/evaluation.yaml index 1bd9f4e93..efca03da4 100644 --- a/verl/trainer/config/evaluation.yaml +++ b/verl/trainer/config/evaluation.yaml @@ -10,4 +10,5 @@ custom_reward_function: name: compute_score ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. \ No newline at end of file + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index b70126839..c19cfed95 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -1,6 +1,7 @@ trainer: nnodes: 1 n_gpus_per_node: 8 + device: cuda data: path: ./data/test/simulation__cruxeval-o_800.parquet @@ -33,18 +34,22 @@ rollout: max_num_seqs: 1024 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: 8 - # for fire vllm rollout - use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236 # for hf rollout do_sample: True disable_log_stats: True enable_chunked_prefill: True n: 1 + # support logging rollout prob for debugging purpose + calculate_log_probs: False actor: strategy: fsdp # This is for backward-compatibility ulysses_sequence_parallel_size: 1 # sp size + entropy_from_logits_with_chunking: False # calculate entropy with chunking to reduce memory peak + entropy_checkpointing: False # recompute entropy fsdp_config: fsdp_size: -1 + forward_prefetch: False # FSDP1 forward_prefetch configuration ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/verl/trainer/config/npu_profile/npu_profile.yaml b/verl/trainer/config/npu_profile/npu_profile.yaml new file mode 100644 index 000000000..b61260375 --- /dev/null +++ b/verl/trainer/config/npu_profile/npu_profile.yaml @@ -0,0 +1,29 @@ +# Options for the npu profiler +options: + + # Storage path of collected data. + save_path: ./profiler_data + + # Collection level, optional values: level_none, level0, level1, level2. + level: level1 + + # Whether to enable memory analysis. + with_memory: False + + # Whether to record tensor shape. + record_shapes: False + + # Whether to record Device-side performance data. + with_npu: True + + # Whether to record Host-side performance data. + with_cpu: True + + # Whether to record Python call stack information. + with_module: False + + # Whether to record operator call stack information. + with_stack: False + + # Whether to automatically parse the data. + analysis: True \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index c916f1059..edafae297 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -1,286 +1,89 @@ -data: - tokenizer: null - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - reward_fn_key: data_source - max_prompt_length: 512 - max_response_length: 512 - train_batch_size: 1024 - gen_batch_size: ${data.train_batch_size} - val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - return_full_prompt: False - shuffle: True - filter_overlong_prompts: False # By Reasoning360. Originally False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. - filter_overlong_prompts_workers: 1 - truncation: error - trust_remote_code: True # main_ppo will check this config to determine whether to use remote code for tokenizer - custom_cls: - path: null - name: null +# specify the default per-component configs +defaults: + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml + - actor@actor_rollout_ref.actor: megatron_actor + # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml + - npu_profile@trainer.npu_profile: npu_profile + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + # load the reference default config, then apply the fields in the current yaml + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: megatron_ref + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + # Critic model config. + - critic@critic: megatron_critic + # Reward model config. + - reward_model@reward_model: megatron_reward_model + - _self_ actor_rollout_ref: hybrid_engine: True + + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron + model: + path: ~/models/deepseek-llm-7b-chat + + custom_chat_template: null + external_lib: null + override_config: + model_config: {} + moe_config: + freeze_moe_router: False + enable_gradient_checkpointing: False + gradient_checkpointing_kwargs: + ## Activation Checkpointing activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity activations_checkpoint_granularity: null # 'selective' or 'full' + # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention activations_checkpoint_num_layers: null # not used with 'selective' + + use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency) + trust_remote_code: False - actor: - strategy: megatron # This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: False - ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} - use_torch_compile: True # False to disable torch compile - # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) - clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified - clip_ratio_low: 0.2 - clip_ratio_high: 0.2 - clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - # NOTE: "token-mean" is the default behavior - entropy_coeff: 0 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo - ppo_epochs: 1 - data_loader_seed: null - shuffle: False - optim: - lr: 1e-6 - clip_grad: 1.0 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - megatron: - param_offload: False - grad_offload: False - optimizer_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: True - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: 42 - override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage - profile: # profile the actor model in `update_policy` - use_profile: False # open it when you want to profile the actor model - profile_ranks: null # list, you can specify the ranks to profile - step_start: -1 # start step in update_policy - step_end: -1 # end step - save_path: null # the path to save the profile result - load_weight: True - checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - ref: - strategy: megatron - use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} - megatron: - param_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: None - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: False - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - profile: - use_profile: False - profile_ranks: null - step_start: -1 - step_end: -1 - save_path: null - load_weight: True - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + rollout: - name: vllm - mode: sync # sync: LLM, async: AsyncLLM - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - prompt_length: ${data.max_prompt_length} # for xperf_gpt - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True + # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: False + load_format: dummy_megatron + tensor_model_parallel_size: 1 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - disable_log_stats: True - enable_chunked_prefill: False # could get higher throughput - # for hf rollout - do_sample: True + layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up - # number of responses (i.e. num sample times) - n: 1 - engine_kwargs: # inference engine parameters - vllm: - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - sglang: - attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla - val_kwargs: - # sampling parameters for validation - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1.0 - temperature: 0 - n: 1 - do_sample: False # default eager for validation - multi_turn: - enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well - max_turns: null # null for no limit (default max_length // 3) - tool_config_path: null # null for no tool - format: chatml # chatml, more formats will be supported in the future -critic: - rollout_n: ${actor_rollout_ref.rollout.n} - strategy: megatron - optim: - lr: 1e-5 - clip_grad: 1.0 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - model: - path: ~/models/deepseek-llm-7b-chat - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: - model_config: {} - moe_config: - freeze_moe_router: False - external_lib: ${actor_rollout_ref.model.external_lib} - trust_remote_code: False - enable_gradient_checkpointing: False - gradient_checkpointing_kwargs: - ## Activation Checkpointing - activations_checkpoint_method: null - activations_checkpoint_granularity: null - activations_checkpoint_num_layers: null - megatron: - param_offload: False - grad_offload: False - optimizer_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: True - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - load_weight: True - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 - forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} - shuffle: ${actor_rollout_ref.actor.shuffle} - cliprange_value: 0.5 - kl_ctrl: - type: fixed - kl_coef: 0.001 - loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - -reward_model: - enable: False - strategy: megatron - megatron: - param_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: False - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: {} - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - trust_remote_code: False - external_lib: ${actor_rollout_ref.model.external_lib} - load_weight: True - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null - use_dynamic_bsz: ${critic.use_dynamic_bsz} - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - max_length: null - reward_manager: naive - launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - sandbox_fusion: - url: null # faas url to run code in cloud sandbox - max_concurrent: 64 # max concurrent requests to sandbox -custom_reward_function: - path: null - name: compute_score + profiler: + _target_: verl.utils.profiler.ProfilerConfig + discrete: False + all_ranks: False + ranks: [] custom_reward_function: path: null name: compute_score - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: False algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig gamma: 1.0 lam: 1.0 adv_estimator: gae @@ -288,23 +91,24 @@ algorithm: use_kl_in_reward: False kl_penalty: kl # how to estimate kl divergence kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig type: fixed kl_coef: 0.001 horizon: 10000 target_kl: 0.1 use_pf_ppo: False pf_ppo: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig reweight_method: pow # ["pow", "max_min", "max_random"] weight_pow: 2.0 - filter_groups: - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit trainer: balance_batch: True total_epochs: 30 total_training_steps: null + profile_steps: null # [1,2,5] or [] or null project_name: verl_examples experiment_name: gsm8k logger: ['console', 'wandb'] @@ -312,6 +116,8 @@ trainer: nnodes: 1 n_gpus_per_node: 8 save_freq: -1 + esi_redundant_time: 0 + # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or disable or resume_path if resume_from_path is set resume_from_path: null @@ -326,6 +132,18 @@ trainer: # The timeout for ray worker group to wait for the register center to be ready ray_wait_register_center_timeout: 300 device: cuda - + # see ppo_trainer.yaml for more details + controller_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + worker_nsight_options: + trace: "cuda,nvtx,cublas,ucx" + cuda-memory-usage: "true" + cuda-graph-trace: "graph" + capture-range: "cudaProfilerApi" + capture-range-end: null + kill: none ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index f22622965..925872739 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -1,286 +1,338 @@ -data: - tokenizer: null - use_shm: False - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - reward_fn_key: data_source - max_prompt_length: 512 - max_response_length: 512 - gen_batch_size: ${data.train_batch_size} - train_batch_size: 1024 - val_batch_size: null - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - return_full_prompt: False - shuffle: True - filter_overlong_prompts: False # By Reasoning360. Originally False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. - filter_overlong_prompts_workers: 1 - truncation: error - image_key: images - video_key: videos - trust_remote_code: True # main_ppo will check this config to determine whether to use remote code for tokenizer - custom_cls: - path: null - name: null +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. +# specify the default per-component configs +defaults: + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml + - actor@actor_rollout_ref.actor: dp_actor + + # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml + - npu_profile@trainer.npu_profile: npu_profile + + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: dp_ref + + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + + # Critic model config. + - critic@critic: dp_critic + + # Reward model config. + - reward_model@reward_model: dp_reward_model + + # load the reference default config, then apply the fields in the current yaml + # self config override anything above + - _self_ + +# config for actor, rollout and reference model actor_rollout_ref: - hybrid_engine: True + + # Whether it's a hybrid engine, currently only supports hybrid engine + hybrid_engine: true + + # common configs for the model model: + + # Huggingface model path. This can be either local path or HDFS path. path: ~/models/deepseek-llm-7b-chat - use_shm: False + + # Custom chat template for the model. + custom_chat_template: null + + # Whether to use shared memory (SHM) for accelerating the loading of model weights + use_shm: false + + # Additional Python packages to register huggingface models/tokenizers. external_lib: null - override_config: { } - enable_gradient_checkpointing: True - enable_activation_offload: False - use_remove_padding: False - lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) - lora_alpha: 16 # LoRA scaling factor - target_modules: all-linear # all-linear or [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] - use_liger: False - use_fused_kernels: False - trust_remote_code: False - actor: - strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: False - ppo_max_token_len_per_gpu: 32768 # n * ${data.max_prompt_length} + ${data.max_response_length} - grad_clip: 1.0 - # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) - clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified - clip_ratio_low: 0.2 - clip_ratio_high: 0.2 - clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - entropy_coeff: 0 - use_kl_loss: False # True for GRPO - use_torch_compile: True # False to disable torch compile - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo - ppo_epochs: 1 - shuffle: False - ulysses_sequence_parallel_size: 1 # sp size - checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - optim: - lr: 1e-6 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0 - num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5 - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - fsdp_config: - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - param_offload: False - optimizer_offload: False - offload_policy: False # only for fsdp2, offload param\grad\optimizer during train - reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] - fsdp_size: -1 - ref: - strategy: fsdp - fsdp_config: - param_offload: False - reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + + # Used to override model's original configurations, mainly dropout + override_config: {} + + # Enable gradient checkpointing for actor + enable_gradient_checkpointing: true + + # Enable activation offloading for actor + enable_activation_offload: false + + # Whether to remove padding tokens in inputs during training + use_remove_padding: false + + # Set to positive value to enable LoRA (e.g., 32) + lora_rank: 0 + + # LoRA scaling factor + lora_alpha: 16 + + # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or + # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] + target_modules: all-linear + + # Exclude modules from applying Lora. Similar usage to target_modules and Peft. + # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora. + exclude_modules: null + + # Whether to use Liger for linear layer fusion + use_liger: false + + # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) + use_fused_kernels: false + + # Options for fused kernels. If use_fused_kernels is true, this will be used. + fused_kernel_options: + + # Implementation backend for fused kernels. Options: "triton" or "torch". + impl_backend: torch + + # Whether to enable loading a remote code model + trust_remote_code: false + + # Rollout model config. rollout: - name: vllm - mode: sync # sync: LLM, async: AsyncLLM - chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - use_fire_sampling: False # https://arxiv.org/abs/2410.21236 - prompt_length: ${data.max_prompt_length} # not use for opensource - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_dtensor # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight - layered_summon: False # for huge model, layered summon can save memory (prevent OOM) but make it slower - tensor_model_parallel_size: 2 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - disable_log_stats: True - enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. - # for hf rollout - do_sample: True - # number of responses (i.e. num sample times) - n: 1 # > 1 for grpo - engine_kwargs: # inference engine parameters - vllm: - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - sglang: - attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla - val_kwargs: - # sampling parameters for validation - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1.0 - temperature: 0 - n: 1 - do_sample: False # default eager for validation - multi_turn: - enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well - max_turns: null # null for no limit (default max_length // 3) - tool_config_path: null # null for no tool - format: chatml # chatml, more formats will be supported in the future - -critic: - rollout_n: ${actor_rollout_ref.rollout.n} - strategy: fsdp # [fsdp, fsdp2] - optim: - lr: 1e-5 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - weight_decay: 0.01 - model: - path: ~/models/deepseek-llm-7b-chat - use_shm: False - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: { } - external_lib: ${actor_rollout_ref.model.external_lib} - enable_gradient_checkpointing: True - enable_activation_offload: False - use_remove_padding: False - trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} - fsdp_config: - param_offload: False - optimizer_offload: False - offload_policy: False # only for fsdp2, offload param\grad\optimizer during train - reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - fsdp_size: -1 - lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) - lora_alpha: 16 # LoRA scaling factor - target_modules: all-linear # all-linear or [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - forward_micro_batch_size: ${critic.ppo_micro_batch_size} - forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 - forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} - ulysses_sequence_parallel_size: 1 # sp size - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - shuffle: ${actor_rollout_ref.actor.shuffle} - grad_clip: 1.0 - cliprange_value: 0.5 - loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - -reward_model: - enable: False - strategy: fsdp - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - use_shm: False - external_lib: ${actor_rollout_ref.model.external_lib} - use_remove_padding: False - use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} - trust_remote_code: False - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: False - reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] - fsdp_size: -1 - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null # set a number - max_length: null - ulysses_sequence_parallel_size: 1 # sp size - use_dynamic_bsz: ${critic.use_dynamic_bsz} - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - reward_manager: naive - overlong_buffer: # NOTE: added by Reasoning360 - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: False - - launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - sandbox_fusion: - url: null # faas url to run code in cloud sandbox - max_concurrent: 64 # max concurrent requests to sandbox + + # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + enable_chunked_prefill: True + + # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc. + # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight + load_format: dummy_dtensor + + # for huge model, layered summon can save memory (prevent OOM) but make it slower + layered_summon: False + + # profiler configs + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + +# custom reward function definition custom_reward_function: + + # The path to the file containing your customized reward function. + # If not specified, pre-implemented reward functions will be used. path: null + + # The name of the reward function within the specified file. Default is 'compute_score'. name: compute_score +# config for the algorithm algorithm: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig + + # Discount factor for future rewards gamma: 1.0 + + # Trade-off between bias and variance in the GAE estimator lam: 1.0 + + # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. adv_estimator: gae + + # Whether to normalize advantages by std (specific to GRPO) norm_adv_by_std_in_grpo: True + + # Whether to enable in-reward KL penalty use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence + + # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_penalty: kl + + # KL control configuration kl_ctrl: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig + + # KL control type: "fixed" or "adaptive" type: fixed + + # Initial coefficient for KL penalty kl_coef: 0.001 horizon: 10000 target_kl: 0.1 - use_pf_ppo: False - pf_ppo: - reweight_method: pow # ["pow", "max_min", "max_random"] - weight_pow: 2.0 filter_groups: # NOTE: added by Reasoning360 enable: False # We try to avoid forgetting to set enable metric: null # acc / score / seq_reward / seq_final_reward / ... max_num_gen_batches: 0 # Non-positive values mean no upper limit + # Horizon value for adaptive controller (if enabled) + horizon: 10000 + + # Target KL divergence (used for adaptive controller) + target_kl: 0.1 + + # Whether to enable preference feedback PPO + use_pf_ppo: False + + # Preference feedback PPO settings + pf_ppo: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + + # Method for reweighting samples: "pow", "max_min", or "max_random" + reweight_method: pow + + # Power used for weight scaling in "pow" method + weight_pow: 2.0 + +# config for the trainer trainer: + + # Whether to balance batch sizes across distributed workers balance_batch: True + + # Number of epochs in training total_epochs: 30 + + # Total training steps (can be set explicitly or derived from epochs) total_training_steps: null + + # The steps that will be profiled. null means no profiling. null or [1,2,5,...] + profile_steps: null + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # Project name for experiment tracking (e.g., wandb) project_name: verl_examples + + # Experiment name for run identification in tracking tools experiment_name: gsm8k + + # Logging backends to use: "console", "wandb", etc. logger: [ 'console', 'wandb' ] + + # Number of generations to log during validation log_val_generations: 0 - rollout_data_dir: null # directory for logging the rollout data, no dump if null - validation_data_dir: null # directory for logging the validation data, no dump if null + + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # Directory for logging validation data; no dump if null + validation_data_dir: null + + # Number of nodes used in the training nnodes: 1 + + # Number of GPUs per node n_gpus_per_node: 8 + + # Save frequency (by iteration) for model checkpoints save_freq: -1 - # auto: find the last ckpt to resume. If can't find, start from scratch - resume_mode: auto # or disable or resume_path if resume_from_path is set + + # ESI refers to the elastic server instance used during training, similar to the training plan. For example, + # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. + # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. + # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. + # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. + esi_redundant_time: 0 + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (only used when resume_mode is "resume_path") resume_from_path: null + + # Whether to run validation before training begins val_before_train: True + + # Whether to run validation only + val_only: False + + # Validation frequency (in training iterations) test_freq: -1 + + # Number of iterations to warm up the critic before updating policy critic_warmup: 0 + + # Default path to distributed filesystem for saving checkpoints default_hdfs_dir: null + + # Whether to delete local checkpoints after loading del_local_ckpt_after_load: False + + # Default local directory for saving checkpoints default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + + # Maximum number of actor checkpoints to keep max_actor_ckpt_to_keep: null + + # Maximum number of critic checkpoints to keep max_critic_ckpt_to_keep: null - # The timeout for ray worker group to wait for the register center to be ready + + # Timeout (in seconds) for Ray worker to wait for registration ray_wait_register_center_timeout: 300 + + # Device to run training on (e.g., "cuda", "cpu") device: cuda +# configs related to ray initialization ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null + + # Path to save Ray timeline JSON for performance profiling + timeline_json_file: null diff --git a/verl/trainer/config/ref/dp_ref.yaml b/verl/trainer/config/ref/dp_ref.yaml new file mode 100644 index 000000000..13b604718 --- /dev/null +++ b/verl/trainer/config/ref/dp_ref.yaml @@ -0,0 +1,38 @@ +# defaults specify the default config from each component +defaults: + + # dp ref config, inheriting from trainer/config/ref/ref.yaml + - ref + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# config for FSDP strategy +fsdp_config: + + # whether to offload parameters in FSDP + param_offload: False + + # whether to perform reshard after model forward to save memory. + # only for fsdp2, [True, False, int between 1 and fsdp_size] + reshard_after_forward: True + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + + # the wrap policy for FSDP model + wrap_policy: + + # minimum number of params in a wrapped module + min_num_params: 0 + +# sequence parallel size +# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1 +ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False diff --git a/verl/trainer/config/ref/megatron_ref.yaml b/verl/trainer/config/ref/megatron_ref.yaml new file mode 100644 index 000000000..6a75d68e3 --- /dev/null +++ b/verl/trainer/config/ref/megatron_ref.yaml @@ -0,0 +1,51 @@ +# megatron ref config, inheriting from trainer/config/ref/ref.yaml +defaults: + - ref + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +megatron: + + param_offload: False + + tensor_model_parallel_size: 1 + + expert_model_parallel_size: 1 + + expert_tensor_parallel_size: None + + pipeline_model_parallel_size: 1 + + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + + context_parallel_size: 1 + + sequence_parallel: True + + use_distributed_optimizer: False + + use_dist_checkpointing: False + + dist_checkpointing_path: null + + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + +profile: + + use_profile: False + + profile_ranks: null + + step_start: -1 + + step_end: -1 + + save_path: null + +load_weight: True \ No newline at end of file diff --git a/verl/trainer/config/ref/ref.yaml b/verl/trainer/config/ref/ref.yaml new file mode 100644 index 000000000..7d9157b3e --- /dev/null +++ b/verl/trainer/config/ref/ref.yaml @@ -0,0 +1,21 @@ +# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default +strategy: ${actor_rollout_ref.actor.strategy} + +# whether to enable torch.compile +# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1 +use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] +# The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# the max token length per GPU +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} diff --git a/verl/trainer/config/reward_model/dp_reward_model.yaml b/verl/trainer/config/reward_model/dp_reward_model.yaml new file mode 100644 index 000000000..d9a837032 --- /dev/null +++ b/verl/trainer/config/reward_model/dp_reward_model.yaml @@ -0,0 +1,51 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: fsdp + +model: + + # Whether to use shared memory for loading the model + use_shm: False + + # Use remove padding optimization (saves compute) + use_remove_padding: False + + # Whether to use fused reward kernels for speedup + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + + # FSDP-specific config + fsdp_config: + + # Policy for wrapping layers with FSDP + wrap_policy: + # Minimum number of parameters to trigger wrapping + min_num_params: 0 + + # Whether to offload model parameters to CPU + param_offload: False + + # Only for FSDP2: Reshard after forward pass to reduce memory footprint + reshard_after_forward: True + + # Number of GPUs in each FSDP shard group; -1 means auto + fsdp_size: -1 + + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather + # before the current forward computation. + forward_prefetch: False + +# Sequence parallelism size for Ulysses-style model parallelism +ulysses_sequence_parallel_size: 1 \ No newline at end of file diff --git a/verl/trainer/config/reward_model/megatron_reward_model.yaml b/verl/trainer/config/reward_model/megatron_reward_model.yaml new file mode 100644 index 000000000..2c5d35cd5 --- /dev/null +++ b/verl/trainer/config/reward_model/megatron_reward_model.yaml @@ -0,0 +1,61 @@ +# defaults specify the default config from each component +defaults: + + # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml + - reward_model + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value +# if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# Megatron parallelism & checkpointing config +megatron: + # Whether to offload model parameters to CPU + param_offload: False + + # Number of GPUs in tensor model parallel group + tensor_model_parallel_size: 1 + + # Number of GPUs in expert model parallel group + expert_model_parallel_size: 1 + + # Expert tensor parallel size + expert_tensor_parallel_size: null + + # Number of pipeline model parallel stages + pipeline_model_parallel_size: 1 + + # change VPP interface for parallelism tests + virtual_pipeline_model_parallel_size: null + + # Context parallel size + context_parallel_size: 1 + + # Whether to use sequence parallelism + sequence_parallel: True + + # Whether to use distributed optimizer + use_distributed_optimizer: False + + # Whether to enable distributed checkpointing + use_dist_checkpointing: False + + # Path for distributed checkpoints + dist_checkpointing_path: null + + # RNG seed for megatron + seed: ${actor_rollout_ref.actor.megatron.seed} + + # Any overrides to transformer config + override_transformer_config: {} + + # Whether to use mbridge for faster comms + use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} + +# Whether to load weights (default True) +load_weight: True \ No newline at end of file diff --git a/verl/trainer/config/reward_model/reward_model.yaml b/verl/trainer/config/reward_model/reward_model.yaml new file mode 100644 index 000000000..698343955 --- /dev/null +++ b/verl/trainer/config/reward_model/reward_model.yaml @@ -0,0 +1,81 @@ +# configs for the reward model + +# Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. +# In GSM8K and Math examples, we disable reward model. +# For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses. +# If False, the following parameters are not effective +enable: False + +# FSDP strategy: "fsdp" or "fsdp2" +strategy: ??? + +# model config for reward scoring +model: + + # Input tokenizer. If the reward model’s chat template is inconsistent with the policy, + # we need to first decode to plaintext, then apply the rm’s chat_template. + # Then score with RM. If chat_templates are consistent, it can be set to null. + # set this to null if the chat template is identical + input_tokenizer: ${actor_rollout_ref.model.path} + + # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. + # Other model types need to define their own RewardModelWorker and pass it from the code. + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + + # External model implementation (optional) + external_lib: ${actor_rollout_ref.model.external_lib} + + # Whether to enable loading a remote code model, default to False + trust_remote_code: False + +# [Deprecated] Global micro batch size +# will be deprecated, use micro_batch_size_per_gpu +micro_batch_size: null + +# Local per-GPU micro batch size +micro_batch_size_per_gpu: null + +# Maximum sequence length to process for scoring +max_length: null + +# Whether to dynamically adjust batch size at runtime +use_dynamic_bsz: ${critic.use_dynamic_bsz} + +# Maximum number of tokens per GPU in one forward pass +forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +# Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources. +# Default is naive. If all verification functions are multiprocessing-safe, +# the reward manager can be set to prime for parallel verification. +reward_manager: naive + +# Whether to launch custom reward function asynchronously during log_prob +# custom reward function executed async on CPU, during log_prob +launch_reward_fn_async: False + +# Cloud/local sandbox fusion configuration for custom reward logic +sandbox_fusion: + + # Cloud /local function URL for sandbox execution + url: null + + # Max concurrent requests allowed to sandbox + max_concurrent: 64 + + # Max memory limit for each sandbox process in MB + memory_limit_mb: 1024 + +# profiler configs +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] \ No newline at end of file diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml new file mode 100644 index 000000000..fc3af80d4 --- /dev/null +++ b/verl/trainer/config/rollout/rollout.yaml @@ -0,0 +1,215 @@ +# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future +name: vllm + +# sync: LLM, async: AsyncLLM +mode: sync + +# Sampling temperature for rollout. +temperature: 1.0 + +# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. +top_k: -1 + +# Top-p sampling parameter. Default 1.0. +top_p: 1 + +# typically the same as data max prompt length +# same as data.max_prompt_length if it exists +prompt_length: ${oc.select:data.max_prompt_length,512} + +# typically the same as data max response length +# same as data.max_response_length if it exists +response_length: ${oc.select:data.max_response_length,512} + +# for vllm rollout +# Rollout model parameters type. Align with actor model's FSDP/Megatron type. +dtype: bfloat16 + +# Fraction of GPU memory used by vLLM/SGLang for KV cache. +gpu_memory_utilization: 0.5 + +# Whether to ignore EOS and continue generating after EOS is hit. +ignore_eos: False + +# Whether to disable CUDA graph. Default True to allow cache freeing. +enforce_eager: True + +# Whether to free engine KVCache after generation. Set enforce_eager=True when enabled. +free_cache_engine: True + +# TP size for rollout. Not effective for hf +tensor_model_parallel_size: 2 + +# max number of tokens in a batch +max_num_batched_tokens: 8192 + +# max length for rollout +max_model_len: null + +# max length of sequences +max_num_seqs: 1024 + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# max token length for log_prob computation +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# disable logging statistics +disable_log_stats: True + +# for hf rollout +# Whether to sample during training rollout. False uses greedy sampling. +do_sample: True + +# number of responses (i.e. num sample times). > 1 for grpo +n: 1 + +# Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache) +multi_stage_wake_up: false + +# Extra inference engine arguments (vllm, sglang). +engine_kwargs: + + # for vllm + vllm: + + # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB). + swap_space: null + + # Whether to disable the preprocessor cache for multimodel models. + disable_mm_preprocessor_cache: False + + # for sglang + sglang: + + # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default. + attention_backend: null + +# Sampling parameters used during validation. +val_kwargs: + + # sampling parameters for validation + # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. + top_k: -1 + + # Top-p sampling parameter. Default 1.0. + top_p: 1.0 + + # Sampling temperature for rollout. + temperature: 0 + + # whether to repeat n times for validation + n: 1 + + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: False + +# Multi-turn interaction config for tools or chat. +multi_turn: + + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # null for default callback + completion_callback: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + +# support logging rollout prob for debugging purpose +calculate_log_probs: False + +# [Experimental] agent loop based rollout configs +agent: + + # Number of agent loop workers + num_workers: 8 + + # custom agent loop config path, which should contain list of configs to intialize AgentLoop instances. + # https://hydra.cc/docs/advanced/instantiate_objects/overview/ + # + # - name: react_agent + # _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop + # tools: ["get_current_temperature"] + # - name: math_expression + # _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop + # min_terms: 2 + # max_terms: 6 + agent_loop_config_path: null + + # custom async server configs + custom_async_server: + + # Path to the custom async server implementation + path: null + + # Class name of the custom async server class (e.g. AsyncvLLMServer) + name: null + +# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. +# This parameter controls the maximum payload size for a single weight update request. +# Reference: https://github.com/volcengine/verl/pull/2418 +# Currently only supported in SGLang rollout implementations +# Larger values may improve throughput but increase memory overhead +# Detailed performance comparison: +# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 +# Default value (512MB) is optimized for typical GPU memory configurations +# For the best performance of `rebuild_cuda_tensor`, it is recommended to: +# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` +# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` +# when using Tensor Parallelism (TP) >= 8. +update_weights_bucket_megabytes: 512 + +# trace rollout data +trace: + + # trace backend, support mlflow, weave + backend: null + + # whether translate token id to text in output + token2text: False diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml index d498422f5..c3af1a48f 100644 --- a/verl/trainer/config/sft_trainer.yaml +++ b/verl/trainer/config/sft_trainer.yaml @@ -7,18 +7,16 @@ data: # Single-turn settings prompt_key: question response_key: answer - prompt_dict_keys: ['question'] - response_dict_keys: ['answer'] + prompt_dict_keys: null + response_dict_keys: null # Multi-turn settings multiturn: enable: false # Set to true to use multi-turn dataset messages_key: messages # Key for messages list in multi-turn mode - # NOTE: max_length used by different set of Reasoning360 - # max_length: 1024 # qwen2.5-0.5b on gsm8k - # max_length: 10240 # qwen2.5-7b on limo - # max_length: 16384 # qwen2.5-7b on limo - max_length: 16384 # qwen2.5-32b on limo - truncation: right # NOTE: modified by Reasoning360 + tools_key: tools # Key for tools list in multi-turn mode + enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode + max_length: 1024 + truncation: error balance_dp_token: False chat_template: null custom_cls: @@ -35,7 +33,7 @@ model: cpu_offload: False offload_params: False external_lib: null - enable_gradient_checkpointing: False + enable_gradient_checkpointing: True trust_remote_code: False lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) lora_alpha: 16 # LoRA scaling factor diff --git a/verl/trainer/constants_ppo.py b/verl/trainer/constants_ppo.py new file mode 100644 index 000000000..84350bbd9 --- /dev/null +++ b/verl/trainer/constants_ppo.py @@ -0,0 +1,22 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +PPO_RAY_RUNTIME_ENV = { + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", + }, +} diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index e3f1e7020..866998003 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -43,26 +43,26 @@ import verl.utils.hdfs_io as hdfs_io from verl.utils.dataset import SFTDataset from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.debug import log_gpu_memory_usage -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available +from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( CPUOffloadPolicy, MixedPrecisionPolicy, apply_fsdp2, + fsdp2_clip_grad_norm_, fsdp2_load_full_state_dict, get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, - fsdp2_clip_grad_norm_ ) +from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.py_functional import convert_to_regular_types from verl.utils.torch_dtypes import PrecisionType from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup -from verl.utils.py_functional import convert_to_regular_types from verl.utils.tracking import Tracking from verl.utils.ulysses import ( - gather_outpus_and_unpad, + gather_outputs_and_unpad, get_ulysses_sequence_parallel_world_size, ulysses_pad_and_slice_inputs, ) @@ -85,7 +85,15 @@ def extract_step(path): class FSDPSFTTrainer: - def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset): + def __init__( + self, + config, + device_mesh: DeviceMesh, + ulysses_device_mesh: DeviceMesh, + tokenizer, + train_dataset: Dataset, + val_dataset: Dataset, + ): self.config = config self.device_mesh = device_mesh self.ulysses_device_mesh = ulysses_device_mesh @@ -118,7 +126,9 @@ def _normalize_config_bsz(self): if self.device_mesh.get_rank() == 0: print(f"Normalize batch size by dp {dp_size}") - assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + assert self.config.data.train_batch_size % dp_size == 0, ( + f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + ) self.config.data.train_batch_size //= dp_size @@ -145,7 +155,9 @@ def _build_dataloader(self, train_dataset, val_dataset): if self.device_mesh.get_rank() == 0: print(f"Using FSDP rank {rank} and size {world_size} for data distribution") - self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True) + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True + ) self.train_dataloader = DataLoader( dataset=self.train_dataset, batch_size=config.data.train_batch_size, @@ -155,7 +167,9 @@ def _build_dataloader(self, train_dataset, val_dataset): drop_last=True, ) - self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True) + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True + ) self.val_dataloader = DataLoader( dataset=self.val_dataset, batch_size=config.data.micro_batch_size_per_gpu, @@ -185,11 +199,17 @@ def _build_model_optimizer(self): # load config first config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) self.model_config = config + if hasattr(self.model_config, "max_position_embeddings"): + self.model_config.max_position_embeddings = max( + self.model_config.max_position_embeddings, self.config.data.max_length + ) if self.config.ulysses_sequence_parallel_size > 1: assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" # This may be very large - init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(): self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( @@ -228,7 +248,9 @@ def _build_model_optimizer(self): log_gpu_memory_usage("After model allocation", logger=logger) - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) auto_wrap_policy = get_fsdp_wrap_policy( self.model, @@ -251,7 +273,7 @@ def _build_model_optimizer(self): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=get_torch_device().current_device(), + device_id=get_device_id(), sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision, sync_module_states=True, @@ -260,8 +282,9 @@ def _build_model_optimizer(self): ) elif fsdp_strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, - cast_forward_inputs=True) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + ) fsdp_kwargs = { "mesh": self.device_mesh, @@ -291,14 +314,21 @@ def _build_model_optimizer(self): self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs if self.device_mesh.get_rank() == 0: - print(f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}") + print( + f"Number of steps/epoch {self.steps_per_epoch}, number of epochs " + f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}" + ) num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": - self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps) + self.lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) elif self.config.optim.lr_scheduler == "wsd": - self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps) + self.lr_scheduler = get_wsd_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) else: raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") @@ -319,7 +349,9 @@ def _compute_loss_and_backward(self, batch, do_backward=True): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() - output = self.fsdp_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) + output = self.fsdp_model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) logits = output.logits shift_logits = logits[..., :-1, :].contiguous() @@ -340,17 +372,25 @@ def _compute_loss_and_backward(self, batch, do_backward=True): batch_size, seqlen = input_ids.shape # Remove padding - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # Unpad position_ids to align rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # Pad and slice inputs for sequence parallelism - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) # For computing loss input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() + ) input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) # Forward pass @@ -366,10 +406,12 @@ def _compute_loss_and_backward(self, batch, do_backward=True): input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) # Gather and unpad for sequence parallelism - loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) # This is the loss collected from all ulysses ranks - full_loss = pad_input(hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen) + full_loss = pad_input( + hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss full_loss = full_loss.reshape(-1) loss_mask = loss_mask.to(full_loss.device) @@ -405,9 +447,9 @@ def training_step(self, batch: TensorDict): loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches step_loss += loss.item() - if self.config.model.strategy == 'fsdp': + if self.config.model.strategy == "fsdp": grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) - elif self.config.model.strategy == 'fsdp2': + elif self.config.model.strategy == "fsdp2": grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) else: raise NotImplementedError(f"not implement {self.config.model.strategy}") @@ -435,7 +477,7 @@ def training_step(self, batch: TensorDict): torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) elif is_npu_available: torch.distributed.all_reduce(step_loss) - step_loss /= self.ulysses_device_mesh.size(0) + step_loss /= self.device_mesh.size(0) return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} def validation_step(self, batch: TensorDict): @@ -446,7 +488,7 @@ def validation_step(self, batch: TensorDict): torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) elif is_npu_available: torch.distributed.all_reduce(loss) - loss /= self.ulysses_device_mesh.size(0) + loss /= self.device_mesh.size(0) return loss def save_checkpoint(self, step): @@ -523,7 +565,7 @@ def fit(self): self.train_dataloader, total=self.steps_per_epoch, desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", - disable=rank != 0 + disable=rank != 0, ): global_step += 1 data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) @@ -540,7 +582,9 @@ def fit(self): # Perform validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( + self.device_name + ) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -565,7 +609,11 @@ def run_sft(config): device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) + ulysses_device_mesh = init_device_mesh( + device_type=device_name, + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=("dp", "sp"), + ) # build tokenizer and datasets first from verl.utils import hf_tokenizer @@ -574,7 +622,14 @@ def run_sft(config): train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) - trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset) + trainer = FSDPSFTTrainer( + config=config, + device_mesh=device_mesh, + ulysses_device_mesh=ulysses_device_mesh, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) trainer.fit() diff --git a/verl/trainer/main_eval.py b/verl/trainer/main_eval.py index d287e0057..5b8246a8e 100644 --- a/verl/trainer/main_eval.py +++ b/verl/trainer/main_eval.py @@ -11,19 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ -Offline evaluation script for generated sequences using a reward model and ground truth verifier. - -This script reads a parquet file containing generated sequences and (optionally) ground truth, -computes reward scores for each response, and calculates pass@k metrics using an unbiased estimator. -Results are saved as a JSON file for further analysis. - -Usage: - python main_eval.py +Offline evaluate the performance of a generated file using reward model and ground truth verifier. +The input is a parquet file that contains N generated sequences and (optional) the ground truth. """ -import json from collections import defaultdict import hydra @@ -34,152 +26,51 @@ from verl.trainer.ppo.reward import get_custom_reward_fn from verl.utils.fs import copy_to_local -from verl.utils.reward_score import default_compute_score - - -# --------------------------------------------------------------------------- # -# Unbiased pass@k estimator -# Formula: 1 - C(n-c, k) / C(n, k) -# --------------------------------------------------------------------------- # -def unbiased_pass_at_k(n: int, c: int, k: int) -> float: - """ - Compute the unbiased pass@k estimate as described in Chen et al. (2021). - - Args: - n (int): Total number of generated samples for this problem. - c (int): Number of correct samples (score == 1.0). - k (int): Target k value. - - Returns: - float: Unbiased pass@k estimate. - - Raises: - ValueError: If k > n. - """ - if k > n: - raise ValueError(f"k = {k} cannot be greater than n = {n}") - if n - c < k: # Not enough incorrect samples for k => pass@k = 1 - return 1.0 - prod = 1.0 - # ∏_{j=n-c+1}^{n} (1 - k / j) == C(n-c, k) / C(n, k) - for j in range(n - c + 1, n + 1): - prod *= 1.0 - k / j - return 1.0 - prod @ray.remote -def process_item(reward_fn, data_source, response_lst, reward_data, extra_info): - """ - Ray remote function to process a single data item. - - Args: - reward_fn (callable): Reward function to evaluate responses. - data_source: The data source for this item. - response_lst (list): List of generated responses. - reward_data (dict): Reward model data, including ground truth. - extra_info: Any extra information for scoring. - - Returns: - tuple: (data_source, score_lst) where score_lst is a list of scores for each response. - """ +def process_item(reward_fn, data_source, response_lst, reward_data): ground_truth = reward_data["ground_truth"] - score_lst = [reward_fn(data_source, r, ground_truth, extra_info) for r in response_lst] - score_lst = [s["score"] for s in score_lst] - return ( - data_source, - score_lst, # a list of scores for each response - ) + score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] + return data_source, np.mean(score_lst) @hydra.main(config_path="config", config_name="evaluation", version_base=None) def main(config): - """ - Main evaluation entry point. Loads data, computes reward scores, and calculates pass@k metrics. - - Args: - config: Hydra configuration object. - """ - # Copy data to local (optionally using shared memory) local_path = copy_to_local(config.data.path, use_shm=config.data.get("use_shm", False)) - - # Load dataset using polars for livecodebench, otherwise pandas - if "livecodebench" in local_path: - import polars as pl - - dataset = pl.read_parquet(local_path) - else: - dataset = pd.read_parquet(local_path) - - # Extract relevant columns from the dataset + dataset = pd.read_parquet(local_path) responses = dataset[config.data.response_key] data_sources = dataset[config.data.data_source_key] reward_model_data = dataset[config.data.reward_model_key] - try: - extra_info_data = dataset["extra_info"] - except Exception: - extra_info_data = None total = len(dataset) - # Initialize Ray for distributed processing + # Initialize Ray if not ray.is_initialized(): ray.init(num_cpus=config.ray_init.num_cpus) - # Prepare to collect per-data-source rewards + # evaluate test_score based on data source data_source_reward = defaultdict(list) - # Use custom reward function if provided, otherwise default - compute_score = get_custom_reward_fn(config) or default_compute_score - - # Create Ray remote tasks for each data item - remote_tasks = [process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i], extra_info_data[i] if extra_info_data is not None else dict()) for i in range(total)] - - # Compute max_k (number of responses per item) and candidate k values (powers of 2) - if isinstance(responses, pd.Series) or isinstance(responses, pl.Series): - max_k = len(responses.to_list()[-1]) - else: - # numpy array - max_k = len(responses.tolist()[-1]) - candidate_ks = [2**i for i in range(int(np.log2(max_k)) + 1) if 2**i <= max_k] - pass_k_stat = {k: 0 for k in candidate_ks if k <= max_k} - avg_pass = 0 # Sum of average scores for all items - - # Process results as they become available + compute_score = get_custom_reward_fn(config) + + # Create remote tasks + remote_tasks = [ + process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + ] + + # Process results as they come in with tqdm(total=total) as pbar: while len(remote_tasks) > 0: - # Wait for Ray tasks to complete + # Use ray.wait to get completed tasks done_ids, remote_tasks = ray.wait(remote_tasks) for result_id in done_ids: - data_source, score_lst = ray.get(result_id) - # Count the number of correct responses (score == 1.0) - pass_count = sum(1 for score in score_lst if score == 1) - avg_score = float(np.mean(score_lst)) - avg_pass += avg_score - data_source_reward[data_source].append(avg_score) + data_source, score = ray.get(result_id) + data_source_reward[data_source].append(score) pbar.update(1) - # For each candidate k, update unbiased pass@k statistics - for k_val, _ in enumerate(score_lst, start=1): - if k_val in candidate_ks: - pass_k_stat[k_val] += unbiased_pass_at_k(max_k, pass_count, k_val) - - # Prepare output metrics - metric_output_path = config.data.path.replace(".parquet", "_metric.json") - metric_data = { - # Unbiased pass@k for each candidate k - **{f"pass@{k_val}": pass_k_stat[k_val] / total * 100.0 for k_val in candidate_ks}, - # Traditional average pass@1 metric - f"pass@1_(avg{max_k})": avg_pass / total * 100.0, - } - # Save metrics to JSON file - with open(metric_output_path, "w") as f: - json.dump(metric_data, f, indent=4) - - print(metric_data) - - # Print per-data-source average scores metric_dict = {} for data_source, rewards in data_source_reward.items(): - metric_dict[f"test_score(avg@k)/{data_source}"] = float(np.mean(rewards)) + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) print(metric_dict) diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index ee9011629..394927146 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -35,7 +35,6 @@ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.utils import hf_tokenizer -from verl.utils.device import is_cuda_available from verl.utils.fs import copy_to_local from verl.utils.hdfs_io import makedirs from verl.utils.model import compute_position_id_with_mask @@ -176,7 +175,11 @@ def main_task(config): ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu") + wg = RayWorkerGroup( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + device_name=config.trainer.device, + ) wg.init_model() # NOTE: updated by Reasoning360. Sample n times together @@ -246,6 +249,8 @@ def main_task(config): # Check if 'aime' is in the output path to determine if we should merge responses should_merge_aime = "aime" in config.data.output_path.lower() + # add to the data frame + dataset["responses"] = output_lst if should_merge_aime: print("Detected 'aime' in output path, merging responses by prompt content...") @@ -263,8 +268,7 @@ def main_task(config): print(f"Saved merged AIME responses to {config.data.output_path}") else: - # Original logic for non-AIME datasets - # add to the data frame + # NOTE: added by Reasoning360. dump results if is_polars_df: import polars as pl diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index bda714f3c..0545923ca 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -15,76 +15,139 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import os +import socket + import hydra import ray +from omegaconf import OmegaConf +import uuid +import hashlib +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import PPO_RAY_RUNTIME_ENV from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ run_ppo(config) +# Define a function to run the PPO-like training process def run_ppo(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized if not ray.is_initialized(): - # this is for local ray cluster + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true"}}, + runtime_env=PPO_RAY_RUNTIME_ENV, num_cpus=config.ray_init.num_cpus, ) - runner = TaskRunner.remote() + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.trainer.get("profile_steps") is not None + and len(config.trainer.get("profile_steps", [])) > 0 + ): + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() ray.get(runner.run.remote(config)) + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_init.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + """ + def run(self, config): - # print initial config + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. from pprint import pprint from omegaconf import OmegaConf from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) OmegaConf.resolve(config) - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)) + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) - # instantiate tokenizer + # Instantiate the tokenizer and processor. from verl.utils import hf_processor, hf_tokenizer trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) # used for multimodal LLM, could be none + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) - # vllm early verify - if config.actor_rollout_ref.rollout.name in ["vllm"]: - from verl.utils.vllm_utils import is_version_ge - - if config.actor_rollout_ref.model.get("lora_rank", 0) > 0: - if not is_version_ge(pkg="vllm", minver="0.7.3"): - raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3") - - # define worker classes - if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: - assert config.critic.strategy in ["fsdp", "fsdp2"] + # Define worker classes based on the actor strategy. + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker - actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker - actor_rollout_cls = ActorRolloutRefWorker + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -92,11 +155,14 @@ def run(self, config): from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + # Map roles to their corresponding remote worker classes. role_worker_mapping = { Role.ActorRollout: ray.remote(actor_rollout_cls), Role.Critic: ray.remote(CriticWorker), } + # Define the resource pool specification. + # Map roles to the resource pool. global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, @@ -106,14 +172,14 @@ def run(self, config): Role.Critic: global_pool_id, } - # we should adopt a multi-source reward function here + # We should adopt a multi-source reward function here: # - for rule-based rm, we directly call a reward score # - for model-based rm, we call a model # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data + # finally, we combine all the rewards together + # The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy in ["fsdp", "fsdp2"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: from verl.workers.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker @@ -122,20 +188,28 @@ def run(self, config): role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) mapping[Role.RewardModel] = global_pool_id - # use reference model + # Add a reference policy worker if KL loss or KL reward is used. if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) - val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})) + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) from verl.utils.dataset.rl_dataset import collate_fn - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. trainer = RayPPOTrainer( config=config, tokenizer=tokenizer, @@ -151,14 +225,17 @@ def run(self, config): train_sampler=train_sampler, device_name=config.trainer.device, ) + # Initialize the workers of the trainer. trainer.init_workers() + # Start the training process. trainer.fit() -def create_rl_dataset(data_paths, data_config, tokenizer, processor): +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True): """Create a dataset. Arguments: + data_paths: List of paths to data files. data_config: The data config. tokenizer (Tokenizer): The tokenizer. processor (Processor): The processor. @@ -170,16 +247,33 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor): from verl.utils.dataset.rl_dataset import RLHFDataset + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + raise NotImplementedError("Custom dataset class is not supported yet, please use RLHFDataset instead.") from verl.utils.import_utils import load_extern_type dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset if not issubclass(dataset_cls, Dataset): - raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset") + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + else: + # Use the default RLHFDataset class if no custom class is specified dataset_cls = RLHFDataset print(f"Using dataset class: {dataset_cls.__name__}") + # Instantiate the dataset using the determined dataset class dataset = dataset_cls( data_files=data_paths, tokenizer=tokenizer, @@ -187,6 +281,21 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor): config=data_config, ) + # create a new feature called "prompt_id" to identify the prompt + def generate_simple_prompt_id(data_source, extra_info): + """Generate a simple unique prompt ID using data_source + split + index""" + split = extra_info.get("split", "unknown") + # if "original_question" in extra_info and extra_info["original_question"] is not None: + # prompt_bytes = extra_info["original_question"].encode('utf-8') + # sha256_hash = hashlib.sha256(prompt_bytes).hexdigest() + # else: + random_id = str(uuid.uuid4()) + return f"{data_source}_{split}_{random_id}" + + dataset.dataframe["prompt_id"] = dataset.dataframe.apply( + lambda row: generate_simple_prompt_id(row["data_source"], row["extra_info"]), axis=1 + ) + dataset.dataframe["on_policy_pass_rate"] = 0.0 return dataset @@ -203,12 +312,30 @@ def create_rl_sampler(data_config, dataset): import torch from torch.utils.data import RandomSampler, SequentialSampler - # use sampler for better ckpt resume - if data_config.shuffle: + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(data_config.get("seed", 1)) sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. sampler = SequentialSampler(data_source=dataset) return sampler diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 71028b3a7..143d733c7 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -15,15 +15,117 @@ """ Core functions to implement PPO algorithms. The function implemented in this file should be used by trainer with different distributed strategies to -implement PPO +implement PPO-like algorithms. """ +__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] + from collections import defaultdict +from enum import Enum +from typing import Optional import numpy as np import torch import verl.utils.torch_functional as verl_F +from verl.trainer.config import AlgoConfig + +POLICY_LOSS_REGISTRY = {} + + +def register_policy_loss(name): + """Register a policy loss function with the given name. + + Args: + name (str): The name to register the policy loss function under. + + Returns: + function: Decorator function that registers the policy loss function. + """ + + def decorator(func): + POLICY_LOSS_REGISTRY[name] = func + return func + + return decorator + + +def get_policy_loss_fn(name): + """Get the policy loss with a given name. + + Args: + name: `(str)` + The name of the policy loss. + + Returns: + `(callable)`: The policy loss function. + """ + loss_name = name + if loss_name not in POLICY_LOSS_REGISTRY: + raise ValueError( + f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}" + ) + return POLICY_LOSS_REGISTRY[loss_name] + + +ADV_ESTIMATOR_REGISTRY = {} + + +def register_adv_est(name_or_enum): + """Decorator to register a advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + """ + + def decorator(fn): + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: + raise ValueError( + f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}" + ) + ADV_ESTIMATOR_REGISTRY[name] = fn + return fn + + return decorator + + +def get_adv_estimator_fn(name_or_enum): + """Get the advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + Returns: + `(callable)`: The advantage estimator function. + """ + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name not in ADV_ESTIMATOR_REGISTRY: + raise ValueError(f"Unknown advantage estimator simply: {name}") + return ADV_ESTIMATOR_REGISTRY[name] + + +class AdvantageEstimator(str, Enum): + """Using an enumeration class to avoid spelling errors in adv_estimator. + + Note(haibin.lin): this enum class is immutable after creation. Extending this + enum for new estimators may not be necessary since users can always just call + `verl.trainer.ppo.core_algos.register` with string name for a custom advantage + estimator instead. + """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + OPO = "opo" + GRPO_PASSK = "grpo_passk" + GPG = "gpg" class AdaptiveKLController: @@ -38,6 +140,12 @@ def __init__(self, init_kl_coef, target_kl, horizon): self.horizon = horizon def update(self, current_kl, n_steps): + """Update the KL coefficient based on current KL divergence. + + Args: + current_kl (float): Current KL divergence value. + n_steps (int): Number of steps taken. + """ target = self.target proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) mult = 1 + proportional_error * n_steps / self.horizon @@ -51,10 +159,28 @@ def __init__(self, kl_coef): self.value = kl_coef def update(self, current_kl, n_steps): + """Update method for fixed KL controller (no-op). + + Args: + current_kl (float): Current KL divergence value (unused). + n_steps (int): Number of steps taken (unused). + """ pass def get_kl_controller(kl_ctrl): + """Factory function to create appropriate KL controller based on configuration. + + Args: + kl_ctrl: Configuration object containing KL controller settings. + + Returns: + KL controller instance (FixedKLController or AdaptiveKLController). + + Raises: + NotImplementedError: If controller type is not supported. + AssertionError: If adaptive controller horizon is not positive. + """ if kl_ctrl.type == "fixed": return FixedKLController(kl_coef=kl_ctrl.kl_coef) elif kl_ctrl.type == "adaptive": @@ -64,6 +190,7 @@ def get_kl_controller(kl_ctrl): raise NotImplementedError +@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") def compute_gae_advantage_return( token_level_rewards: torch.Tensor, values: torch.Tensor, @@ -93,14 +220,19 @@ def compute_gae_advantage_return( """ with torch.no_grad(): + nextvalues = 0 lastgaelam = 0 advantages_reversed = [] gen_len = token_level_rewards.shape[-1] for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] - lastgaelam = delta + gamma * lam * lastgaelam + lastgaelam_ = delta + gamma * lam * lastgaelam + + # skip values and TD-error on observation tokens + nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues + lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam + advantages_reversed.append(lastgaelam) advantages = torch.stack(advantages_reversed[::-1], dim=1) @@ -110,13 +242,15 @@ def compute_gae_advantage_return( # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") def compute_grpo_outcome_advantage( token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: str = True, -): + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). @@ -126,10 +260,18 @@ def compute_grpo_outcome_advantage( shape is (bs, response_length) response_mask: `(torch.Tensor)` shape is (bs, response_length) - norm_adv_by_std_in_grpo: (bool) - whether to scale the GRPO advantage. - If True, the advantage is scaled by the std, as in the original GRPO. - If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + config: `(Optional[AlgoConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). Returns: advantages: `(torch.Tensor)` @@ -146,8 +288,7 @@ def compute_grpo_outcome_advantage( with torch.no_grad(): bsz = scores.shape[0] for i in range(bsz): - id2score[index[i]].append(scores[i]) # index records idx in bsz -> idx in prompt. e.g, [0, 0, 1, 1, 2, 2] -> [0, 1, 2] - # id2score: {0: [1, 1, 0, 1], 1: [0, 0, 0, 1], ..., bsz: [1, 1, 1, 1]} + id2score[index[i]].append(scores[i]) for idx in id2score: if len(id2score[idx]) == 1: id2mean[idx] = torch.tensor(0.0) @@ -167,13 +308,16 @@ def compute_grpo_outcome_advantage( return scores, scores +@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") def compute_grpo_passk_outcome_advantage( token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, norm_adv_by_std_in_grpo: bool = True, -): + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for Pass@k using a GRPO-style outcome reward formulation. Only the best response per group gets a non-zero advantage: r_max - r_second_max. @@ -185,12 +329,15 @@ def compute_grpo_passk_outcome_advantage( response_mask: (bs, response_length) index: (bs,) → group ID per sample epsilon: float for numerical stability - norm_adv_by_std_in_grpo: if True, normalize advantage by std within group + config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" Returns: advantages: (bs, response_length) returns: (bs, response_length) """ + assert config is not None + # if True, normalize advantage by std within group + norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) scores = token_level_rewards.sum(dim=-1) # (bs,) advantages = torch.zeros_like(scores) @@ -207,7 +354,9 @@ def compute_grpo_passk_outcome_advantage( for idx in id2scores: rewards = torch.stack(id2scores[idx]) # (k,) if rewards.numel() < 2: - raise ValueError(f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}.") + raise ValueError( + f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}." + ) topk, topk_idx = torch.topk(rewards, 2) r_max, r_second_max = topk[0], topk[1] i_max = id2indices[idx][topk_idx[0].item()] @@ -221,7 +370,17 @@ def compute_grpo_passk_outcome_advantage( return advantages, advantages -def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6): +@register_adv_est( + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE +) # or simply: @register_adv_est("reinforce_plus_plus_baseline") +def compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: torch.Tensor, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward (with only one scalar reward for each response). @@ -231,6 +390,7 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -264,7 +424,15 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: return scores, scores -def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6): +@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") +def compute_rloo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 @@ -273,6 +441,7 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_m shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -299,13 +468,23 @@ def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_m for i in range(bsz): response_num = len(id2score[index[i]]) if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (response_num - 1) + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( + response_num - 1 + ) scores = scores.unsqueeze(-1) * response_mask return scores, scores -def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6): +@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") +def compute_opo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 @@ -314,6 +493,7 @@ def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_ma shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -350,7 +530,10 @@ def compute_opo_outcome_advantage(token_level_rewards: torch.Tensor, response_ma return scores, scores -def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor): +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +def compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for REINFORCE++. This implementation is based on the paper: https://arxiv.org/abs/2501.03262 @@ -360,6 +543,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -367,7 +551,8 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten Returns: `(torch.Tensor)` shape: (bs, response_length) """ - + assert config is not None + gamma = config.gamma with torch.no_grad(): returns = torch.zeros_like(token_level_rewards) running_return = 0 @@ -384,7 +569,14 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten return advantages, returns -def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor): +@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") +def compute_remax_outcome_advantage( + token_level_rewards: torch.Tensor, + reward_baselines: torch.Tensor, + response_mask: torch.Tensor, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for ReMax, operating only on Outcome reward This implementation is based on the paper: https://arxiv.org/abs/2310.10505 @@ -397,6 +589,7 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba shape: (bs,) response_mask: `(torch.Tensor)` shape: (bs, response_length) + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -412,7 +605,80 @@ def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_ba return advantages, returns +@register_adv_est(AdvantageEstimator.GPG) # or simply: @register_adv_est("gpg") +def compute_gpg_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + f_norm: float = 1.0, + alpha: float = 1.0, + config=None, + **kwargs, +): + """ + Compute advantage for GPG, operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + index: `(np.ndarray)` + shape: (bs,) + epsilon: (float) + f_norm: (float) + alpha: (float) + config: (dict) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + m = torch.count_nonzero(scores) + alpha = bsz / m.clamp(min=1) + + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): + """Compute token-level rewards with KL penalty. + + Args: + token_level_scores (torch.Tensor): Token-level reward scores. + old_log_prob (torch.Tensor): Log probabilities from current policy. + ref_log_prob (torch.Tensor): Log probabilities from reference policy. + kl_ratio (float): KL penalty coefficient. + + Returns: + torch.Tensor: Token-level rewards with KL penalty applied. + """ kl = old_log_prob - ref_log_prob return token_level_scores - kl * kl_ratio @@ -420,6 +686,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): """ Aggregate the loss matrix into a scalar. + Args: loss_mat: `(torch.Tensor)`: shape: (bs, response_length) @@ -427,11 +694,9 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str shape: (bs, response_length) loss_agg_mode: (str) choices: method to aggregate the loss matrix into a scalar. - Returns: loss: `a scalar torch.Tensor` aggregated loss - """ if loss_agg_mode == "token-mean": loss = verl_F.masked_mean(loss_mat, loss_mask) @@ -493,24 +758,36 @@ def compute_policy_loss( loss_agg_mode (str, optional): Aggregation mode for `agg_loss`. Defaults to "token-mean". """ - assert clip_ratio_c > 1.0, "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}." + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) ratio = torch.exp(negative_approx_kl) ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) pg_losses1 = -advantages * ratio + if cliprange_low is None: cliprange_low = cliprange if cliprange_high is None: cliprange_high = cliprange - pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A - clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) pg_losses3 = -advantages * clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) - pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) @@ -518,6 +795,183 @@ def compute_policy_loss( return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower +@register_policy_loss("gpg") +def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None): + """Adapted from + https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 + Args: + log_prob: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + return: + pg_loss: `a scalar torch.Tensor` + policy gradient loss computed via GPG + """ + pg_losses = -log_prob * advantages + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) + + +@register_policy_loss("clip_cov") +def compute_policy_loss_clip_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + clip_cvo_ratio (float, optional): + Ratio for clipping the covariance. Defaults to 0.0002. + clip_cov_lb (float, optional): + Lower bound for clipping covariance. Defaults to 1.0. + clip_cov_ub (float, optional): + Upper bound for clipping covariance. Defaults to 5.0. + """ + clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 + cliprange = config.clip_ratio + cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 + clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 + + assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + corr = torch.ones_like(advantages) + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) + + cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( + log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) + ) + cov_all[response_mask == 0] = -torch.inf + cov_all[clip_by_origin] = -torch.inf + + clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) + top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) + top_k_idx = torch.nonzero(top_k_idx) + + if len(top_k_idx) > 0: + perm = torch.randperm(len(top_k_idx)) + top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] + else: + top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) + + corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 + + pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) + + pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0) + + +@register_policy_loss("kl_cov") +def compute_policy_loss_kl_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + kl_cov_ratio (float, optional): + Ratio for selecting the top-k covariance values. Defaults to 0.0002. + ppo_kl_coef (float, optional): + Coefficient for the KL penalty term in the loss. Defaults to 1. + """ + kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 + ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 + + assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + abs_kl = negative_approx_kl.abs() + ratio = torch.exp(negative_approx_kl) + ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask) + pg_losses1 = -advantages * ratio + pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl + pg_losses = pg_losses1 + + all_valid = response_mask > 0 + all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] + all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu() + all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu() + + k = min(kl_cov_ratio, len(all_valid_adv)) + + if k != 0: + cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) + k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) + large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices + + if len(large_cov_idxs) != 0: + large_cov_idxs = all_valid_idx[large_cov_idxs] + pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ + large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] + ] + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) + + def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): """Compute categorical entropy loss (For backward compatibility) @@ -535,7 +989,14 @@ def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean return entropy_loss -def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str = "token-mean"): +def compute_value_loss( + vpreds: torch.Tensor, + returns: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + cliprange_value: float, + loss_agg_mode: str = "token-mean", +): """ Compute the clipped value-function loss for PPO. @@ -565,7 +1026,7 @@ def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torc vf_losses1 = (vpreds - returns) ** 2 vf_losses2 = (vpredclipped - returns) ** 2 clipped_vf_losses = torch.max(vf_losses1, vf_losses2) - vf_loss = agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) return vf_loss, vf_clipfrac @@ -595,6 +1056,8 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe # # URL http://joschu.net/blog/kl-approx.html. if kl_penalty in ("low_var_kl", "k3"): kl = ref_logprob - logprob + # For numerical stability + kl = torch.clamp(kl, min=-20, max=20) ratio = torch.exp(kl) kld = (ratio - kl - 1).contiguous() return torch.clamp(kld, min=-10, max=10) @@ -624,6 +1087,19 @@ def compute_pf_ppo_reweight_data( @torch.no_grad() def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: + """Compute importance weights for resampling based on scores. + + Args: + scores (torch.Tensor): Tensor of scores to compute weights from. + reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random'). + weight_pow (float): Power exponent for 'pow' method. + + Returns: + torch.Tensor: Computed importance weights. + + Raises: + ValueError: If reweight_method is not supported. + """ if reweight_method == "pow": weights = torch.pow(torch.abs(scores), weight_pow) elif reweight_method == "max_min": diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 31a761d26..fcac9f508 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -17,7 +17,7 @@ from collections import defaultdict from functools import partial -from typing import Any, Callable, Dict, List +from typing import Any, Callable import numpy as np import torch @@ -29,9 +29,8 @@ # NOTE: added by Reasoning360. _scores_tables = {} # Global dictionary to store wandb tables - @deprecated("verl.utils.metric.reduce_metrics") -def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: +def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: """ Reduces a dictionary of metric lists by computing the mean of each list. @@ -51,15 +50,15 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: return reduce_metrics(metrics) -def _compute_response_info(batch: DataProto) -> Dict[str, Any]: +def _compute_response_info(batch: DataProto) -> dict[str, Any]: """ Computes information about prompts and responses from a batch. - + This is an internal helper function that extracts masks and lengths for prompts and responses. - + Args: batch: A DataProto object containing batch data with responses and attention masks. - + Returns: A dictionary containing: - response_mask: Attention mask for the response tokens @@ -81,7 +80,7 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]: ) -def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]: +def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]: """ Computes various metrics from a batch of data for PPO training. @@ -103,6 +102,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) - response_length/mean, max, min, clip_ratio: Statistics about response lengths - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + - num_turns/mean, max, min: Statistics about the number of multi-turn conversations """ sequence_score = batch.batch["token_level_scores"].sum(-1) sequence_reward = batch.batch["token_level_rewards"].sum(-1) @@ -110,10 +110,13 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, advantages = batch.batch["advantages"] returns = batch.batch["returns"] - max_response_length = batch.batch["responses"].shape[-1] + max_response_length_for_mask = batch.batch["responses"].shape[-1] # This is not the actual max value of response lengths, only for masking purposes. - prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() - response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() + # LLM360 changes for length control + # prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + # response_mask = batch.batch["response_mask"].bool() + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length_for_mask].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length_for_mask:].bool() max_prompt_length = prompt_mask.size(-1) @@ -121,6 +124,24 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, prompt_length = response_info["prompt_length"] response_length = response_info["response_length"] + # Handle individual length control + if "per_prompt_length_budget" in batch.non_tensor_batch: + # For individual response lengths (vary_length mode) + per_prompt_length_budget = batch.non_tensor_batch["per_prompt_length_budget"] + if isinstance(per_prompt_length_budget, (list, np.ndarray)): + per_prompt_length_budget = torch.tensor([float(x) for x in per_prompt_length_budget], dtype=torch.float32) + + # Use the maximum across all prompts for reference + max_response_length = torch.max(per_prompt_length_budget).item() + elif batch.meta_info.get("target_max_response_length", None) is not None: + # Traditional approach: single target length for all samples + max_response_length = batch.meta_info["target_max_response_length"] + per_prompt_length_budget = None + else: + # No target length specified, use the maximum actual response length + max_response_length = torch.max(response_length).item() + per_prompt_length_budget = None + valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) @@ -137,6 +158,35 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, data_source_response_lengths[data_source].append(response_length[i].item()) data_source_scores[data_source].append(sequence_score[i].item()) + # Calculate clip ratio based on individual vs batch-level target lengths + if per_prompt_length_budget is not None: + # Individual length control: compare each response to its own target length + # Ensure per_prompt_length_budget is on the same device as response_length + per_prompt_length_budget = per_prompt_length_budget.to(response_length.device) + # Compare each response_length[i] to per_prompt_length_budget[i] + # NOTE: Fixed clip ratio calculation - should measure if response hit the generation limit + # not if meaningful tokens exactly match target length + + # Debug: print some statistics + print(f"[DEBUG] response_length stats: min={torch.min(response_length).item():.1f}, " + f"max={torch.max(response_length).item():.1f}, mean={torch.mean(response_length).item():.1f}") + print(f"[DEBUG] per_prompt_length_budget stats: min={torch.min(per_prompt_length_budget).item():.1f}, " + f"max={torch.max(per_prompt_length_budget).item():.1f}, mean={torch.mean(per_prompt_length_budget).item():.1f}") + print(f"[DEBUG] Number of responses >= target: {torch.sum(torch.ge(response_length, per_prompt_length_budget)).item()}/{len(response_length)}") + + clip_ratio = torch.mean(torch.ge(response_length, per_prompt_length_budget).float()).item() + else: + # Traditional approach: compare all samples to the same target length + # For traditional approach, if max_response_length is from target, use >= comparison + # If it's from actual max length in batch, use == comparison (no clipping occurred) + if batch.meta_info.get("target_max_response_length", None) is not None: + # Target length specified, measure how often we hit/exceed it + clip_ratio = torch.mean(torch.ge(response_length, max_response_length).float()).item() + else: + # No target specified, max_response_length is actual max in batch + # This measures how often responses reached the maximum actual length + clip_ratio = torch.mean(torch.eq(response_length, max_response_length).float()).item() + metrics = { # score "critic/score/mean": torch.mean(sequence_score).detach().item(), @@ -170,7 +220,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, "response_length/mean": torch.mean(response_length).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + "response_length/clip_ratio": clip_ratio, # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), @@ -178,14 +228,45 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), } + # multi-turn conversation + if "__num_turns__" in batch.non_tensor_batch: + num_turns = batch.non_tensor_batch["__num_turns__"] + metrics["num_turns/min"] = num_turns.min() + metrics["num_turns/max"] = num_turns.max() + metrics["num_turns/mean"] = num_turns.mean() + # Add data source specific response length metrics for data_source, lengths in data_source_response_lengths.items(): - lengths_tensor = torch.tensor(lengths) + lengths_tensor = torch.tensor(lengths, device=response_length.device) + + # Calculate clip ratio for this data source + if per_prompt_length_budget is not None: + # Get target lengths for this data source + data_source_indices = [i for i, ds in enumerate(batch.non_tensor_batch['data_source']) if ds == data_source] + if len(data_source_indices) > 0: + data_source_targets = per_prompt_length_budget[data_source_indices] + # Ensure both tensors are on the same device and have the same shape + data_source_targets = data_source_targets.to(response_length.device) + # NOTE: Fixed clip ratio calculation - compare each response to its own target length + data_source_clip_ratio = torch.mean(torch.ge(lengths_tensor, data_source_targets).float()).item() + else: + data_source_clip_ratio = 0.0 + else: + # For traditional approach, use same logic as main clip ratio calculation + if batch.meta_info.get("target_max_response_length", None) is not None: + # Target length specified, measure how often we hit/exceed it + data_source_clip_ratio = torch.mean(torch.ge(lengths_tensor, max_response_length).float()).item() + else: + # No target specified, use == comparison with actual max + data_source_clip_ratio = torch.mean(torch.eq(lengths_tensor, max_response_length).float()).item() + + metrics[f"response_length/{data_source}/clip_ratio"] = data_source_clip_ratio + + # Add other data source specific metrics metrics.update({ f"response_length/{data_source}/mean": torch.mean(lengths_tensor).item(), f"response_length/{data_source}/max": torch.max(lengths_tensor).item(), f"response_length/{data_source}/min": torch.min(lengths_tensor).item(), - f"response_length/{data_source}/clip_ratio": torch.mean(torch.eq(lengths_tensor, max_response_length).float()).item(), }) # Add data source specific reward metrics @@ -197,15 +278,16 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, f"critic/scores/{data_source}/min": torch.min(scores_tensor).item(), f"critic/scores/{data_source}/std": torch.std(scores_tensor).item(), }) + return metrics -def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: +def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]: """ Computes timing metrics for different processing stages in PPO training. - - This function calculates both raw timing metrics (in seconds) and per-token timing metrics - (in milliseconds) for various processing stages like generation, reference computation, + + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, value computation, advantage computation, and model updates. Args: @@ -235,30 +317,33 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di return { **{f"timing_s/{name}": value for name, value in timing_raw.items()}, - **{f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, } -def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: +def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: """ Computes throughput metrics for PPO training. - + This function calculates performance metrics related to token processing speed, including the total number of tokens processed, time per step, and throughput (tokens per second per GPU). - + Args: batch: A DataProto object containing batch data with meta information about token counts. timing_raw: A dictionary mapping stage names to their execution times in seconds. Must contain a "step" key with the total step time. n_gpus: Number of GPUs used for training. - + Returns: A dictionary containing: - perf/total_num_tokens: Total number of tokens processed in the batch - perf/time_per_step: Time taken for the step in seconds - perf/throughput: Tokens processed per second per GPU - + Note: The throughput is calculated as total_tokens / (time * n_gpus) to normalize across different GPU counts. @@ -352,15 +437,16 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo return maj_val -def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42) -> dict[str, dict[str, dict[str, float]]]: +def process_validation_metrics( + data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: """ Process validation metrics into a structured format with statistical analysis. - + This function organizes validation metrics by data source and prompt, then computes various statistical measures including means, standard deviations, best/worst values, and majority voting results. It also performs bootstrap sampling to estimate statistics for different sample sizes. - Args: data_sources: List of data source identifiers for each sample. sample_inputs: List of input prompts corresponding to each sample. @@ -376,7 +462,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] } } } - + Where metric_name includes: - "mean@N": Mean value across N samples - "std@N": Standard deviation across N samples @@ -386,7 +472,7 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] - "worst@N/std": Standard deviation of the worst values in bootstrap samples - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) - + Example: >>> data_sources = ["source1", "source1", "source2"] >>> sample_inputs = ["prompt1", "prompt1", "prompt2"] @@ -425,11 +511,15 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] ns.append(n_resps) for n in ns: - [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed) + [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + ) metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std if var2vals.get("pred", None) is not None: - vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])] + vote_data = [ + {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True) + ] [(maj_n_mean, maj_n_std)] = bootstrap_metric( data=vote_data, subset_size=n, @@ -456,7 +546,9 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str] return data_src2var2metric2val -def compute_difficulty_histogram_metrics(batch: DataProto, config) -> Dict[str, Any]: + +# NOTE: added by Reasoning360 +def compute_difficulty_histogram_metrics(batch: DataProto, config) -> dict[str, Any]: metrics = {} with torch.no_grad(): diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index e9b8c6af9..779e86eb9 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -FSDP PPO Trainer with Ray-based single controller. +PPO Trainer with Ray-based single controller. This trainer supports model-agonistic model initialization with huggingface """ @@ -22,29 +22,30 @@ import os import uuid from collections import defaultdict -from contextlib import contextmanager from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Dict, Optional, Type +from typing import Optional import numpy as np +import pandas as pd import ray import torch -from codetiming import Timer from omegaconf import OmegaConf, open_dict from torch.utils.data import Dataset, Sampler from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.base import Worker from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos -from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss from verl.trainer.ppo.metric_utils import ( compute_data_metrics, compute_difficulty_histogram_metrics, # NOTE: added by Reasoning360 @@ -53,16 +54,16 @@ process_validation_metrics, ) from verl.trainer.ppo.reward import compute_reward, compute_reward_async -from verl.utils.checkpoint.checkpoint_manager import BaseCheckpointManager, find_latest_ckpt_path +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.debug import marked_timer from verl.utils.metric import ( reduce_metrics, ) from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger -from verl.workers.rollout.async_server import AsyncLLMServerManager -WorkerType = Type[Worker] +WorkerType = type[Worker] class Role(Enum): @@ -79,21 +80,6 @@ class Role(Enum): ActorRolloutRef = 6 -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" - REMAX = "remax" - RLOO = "rloo" - OPO = "opo" - GRPO_PASSK = "grpo_passk" - - @dataclass class ResourcePoolManager: """ @@ -105,12 +91,21 @@ class ResourcePoolManager: resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) def create_resource_pool(self): + """Create Ray resource pools for distributed training. + + Initializes resource pools based on the resource pool specification, + with each pool managing GPU resources across multiple nodes. + For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. + For Megatron backend, uses max_colocate_count>1 for different models. + """ for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. # For Megatron backend, we recommend using max_colocate_count>1 # that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name) + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) self.resource_pool_dict[resource_pool_name] = resource_pool self._check_resource_available() @@ -126,13 +121,20 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) for node, node_info in node_available_resources.items()} + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) if total_available_gpus < total_required_gpus: - raise ValueError(f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}") + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) # check each resource pool can be satisfied, O(#resource_pools * #nodes) for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): @@ -144,10 +146,13 @@ def _check_resource_available(self): if num_nodes == 0: break if num_nodes > 0: - raise ValueError(f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" + "cannot be satisfied in this ray cluster") + raise ValueError( + f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" + + "cannot be satisfied in this ray cluster" + ) -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): """Apply KL penalty to the token-level rewards. This function computes the KL divergence between the reference policy and current policy, @@ -164,21 +169,15 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, - The updated data with token-level rewards adjusted by KL penalty - A dictionary of metrics related to the KL penalty """ - responses = data.batch["responses"] - response_length = responses.size(1) + response_mask = data.batch["response_mask"] token_level_scores = data.batch["token_level_scores"] batch_size = data.batch.batch_size[0] - if multi_turn: - loss_mask = data.batch["loss_mask"] - response_mask = loss_mask[:, -response_length:] - else: - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - # compute kl between ref_policy and current policy # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty(data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty) # (batch_size, response_length) + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) kld = kld * response_mask beta = kl_ctrl.value @@ -214,7 +213,15 @@ def compute_response_mask(data: DataProto): return attention_mask[:, -response_length:] -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, **kwargs): +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: """Compute advantage estimates for policy optimization. This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. @@ -222,22 +229,23 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re Args: data (DataProto): The data containing batched model outputs and inputs. - adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. - multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. - norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. Returns: DataProto: The updated data with computed advantages and returns. """ # Back-compatible with trainers that do not compute response mask in fit - if "response_mask" not in data.batch: + if "response_mask" not in data.batch.keys(): data.batch["response_mask"] = compute_response_mask(data) # prepare response group - # TODO: add other ways to estimate advantages if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) advantages, returns = core_algos.compute_gae_advantage_return( token_level_rewards=data.batch["token_level_rewards"], values=data.batch["values"], @@ -247,19 +255,15 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re ) data.batch["advantages"] = advantages data.batch["returns"] = returns - if kwargs.get("use_pf_ppo", False): + if config.get("use_pf_ppo", False): data = core_algos.compute_pf_ppo_reweight_data( data, - kwargs.get("pf_ppo_reweight_method", "pow"), - kwargs.get("pf_ppo_weight_pow", 2.0), + config.pf_ppo.reweight_method, + config.pf_ppo.weight_pow, ) elif adv_estimator == AdvantageEstimator.GRPO: - # TODO: test on more adv estimator type + # Initialize the mask for GRPO calculation grpo_calculation_mask = data.batch["response_mask"] - if multi_turn: - # If multi-turn, replace the mask with the relevant part of loss_mask - response_length = grpo_calculation_mask.size(1) # Get length from the initial response mask - grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # This mask is the one intended for GRPO # Call compute_grpo_outcome_advantage with parameters matching its definition advantages, returns = core_algos.compute_grpo_outcome_advantage( token_level_rewards=data.batch["token_level_rewards"], @@ -269,85 +273,32 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re ) data.batch["advantages"] = advantages data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.GRPO_PASSK: - advantages, returns = core_algos.compute_grpo_passk_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE: - advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - gamma=gamma, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - response_mask=data.batch["response_mask"], - ) + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) data.batch["advantages"] = advantages data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.OPO: - advantages, returns = core_algos.compute_opo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError return data -@contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): - """Context manager for timing code execution. - - This utility function measures the execution time of code within its context - and accumulates the timing information in the provided dictionary. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - with Timer(name=name, logger=None) as timer: - yield - if name not in timing_raw: - timing_raw[name] = 0 - timing_raw[name] += timer.last - - class RayPPOTrainer: - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, and vLLM integration. """ # TODO: support each role have individual ray_worker_group_cls, @@ -368,8 +319,27 @@ def __init__( train_sampler: Optional[Sampler] = None, device_name="cuda", ): - """Initialize distributed PPO trainer with Ray backend.""" + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda". + """ + # Store the tokenizer for text processing self.tokenizer = tokenizer self.processor = processor self.config = config @@ -395,8 +365,8 @@ def __init__( # define in-reward KL control # kl loss control currently not suppoorted - if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True @@ -408,6 +378,7 @@ def __init__( AdvantageEstimator.RLOO, AdvantageEstimator.OPO, AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, + AdvantageEstimator.GPG, ]: self.use_critic = False else: @@ -421,20 +392,46 @@ def _validate_config(self): # number of GPUs total n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes if config.actor_rollout_ref.actor.strategy == "megatron": - model_parallel_size = config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size - assert n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0, f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" - megatron_dp = n_gpus // (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) + model_parallel_size = ( + config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size + * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size + ) + assert ( + n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0 + ), ( + f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times " + f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" + ) + megatron_dp = n_gpus // ( + model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size + ) minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu else: minimal_bsz = n_gpus # 1. Check total batch size for data correctness real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % minimal_bsz == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size ({minimal_bsz})" + assert real_train_batch_size % minimal_bsz == 0, ( + f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " + f"({minimal_bsz})" + ) # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + """Validate mutually exclusive micro batch size configuration options. + + Ensures that users don't set both deprecated micro_batch_size and + the new micro_batch_size_per_gpu parameters simultaneously. + + Args: + mbs: Deprecated micro batch size parameter value. + mbs_per_gpu: New micro batch size per GPU parameter value. + name (str): Configuration section name for error messages. + + Raises: + ValueError: If both parameters are set or neither is set. + """ settings = { "actor_rollout_ref.actor": "micro_batch_size", "critic": "micro_batch_size", @@ -448,10 +445,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): param_per_gpu = f"{param}_per_gpu" if mbs is None and mbs_per_gpu is None: - raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + raise ValueError( + f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." + ) if mbs is not None and mbs_per_gpu is not None: - raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'" + "is supported (the former is deprecated).") + raise ValueError( + f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " + f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." + ) if not config.actor_rollout_ref.actor.use_dynamic_bsz: # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu @@ -478,11 +480,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if self.use_critic and not config.critic.use_dynamic_bsz: # Check for critic micro-batch size conflicts - check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic") + check_mutually_exclusive( + config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + ) # Check for reward model micro-batch size conflicts if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model") + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) # Actor # check if train_batch_size is larger than ppo_mini_batch_size @@ -493,7 +499,11 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert ( + config.actor_rollout_ref.actor.ppo_mini_batch_size + % config.actor_rollout_ref.actor.ppo_micro_batch_size + == 0 + ) assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus assert config.actor_rollout_ref.actor.loss_agg_mode in [ @@ -503,7 +513,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "seq-mean-token-sum-norm", ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" - if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: print("NOTICE: You have both enabled in-reward kl and kl loss.") # critic @@ -515,28 +525,36 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1): - assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"} and ( + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + ): + assert config.actor_rollout_ref.model.use_remove_padding, ( + "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) - if self.use_critic and config.critic.strategy == "fsdp": + if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`." + assert config.critic.model.use_remove_padding, ( + "When using sequence parallelism for critic, you must enable `use_remove_padding`." + ) if config.data.get("val_batch_size", None) is not None: - print("WARNING: val_batch_size is deprecated." + " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.") + print( + "WARNING: val_batch_size is deprecated." + + " Validation datasets are sent to inference engines as a whole batch," + + " which will schedule the memory themselves." + ) # check eval config if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample" - - # check multi_turn with tool config - if config.actor_rollout_ref.rollout.multi_turn.enable: - assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None, "tool_config_path must be set when enabling multi_turn with tool, due to no role-playing support" - assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool" + assert config.actor_rollout_ref.rollout.temperature > 0, ( + "validation gen temperature should be greater than 0 when enabling do_sample" + ) print("[validate_config] All configuration checks passed successfully!") - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. """ @@ -544,9 +562,13 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler if train_dataset is None: - train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor) + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) if val_dataset is None: - val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor) + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) self.train_dataset, self.val_dataset = train_dataset, val_dataset if train_sampler is None: @@ -556,10 +578,12 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl collate_fn = default_collate_fn + num_workers = self.config.data["dataloader_num_workers"] + self.train_dataloader = StatefulDataLoader( dataset=self.train_dataset, batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=self.config.data.get("dataloader_num_workers", 8), + num_workers=num_workers, drop_last=True, collate_fn=collate_fn, sampler=train_sampler, @@ -572,8 +596,8 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.val_dataloader = StatefulDataLoader( dataset=self.val_dataset, batch_size=val_batch_size, - num_workers=self.config.data.get("dataloader_num_workers", 8), - shuffle=False, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), drop_last=False, collate_fn=collate_fn, ) @@ -581,7 +605,10 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" - print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs @@ -601,6 +628,15 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl except Exception as e: print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + self.pre_train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=val_batch_size, + num_workers=self.config.data.get("dataloader_num_workers", 8), + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): """Dump rollout/validation samples as JSONL.""" os.makedirs(dump_path, exist_ok=True) @@ -618,10 +654,13 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du if len(v) == n: base_data[k] = v + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + with open(filename, "w") as f: - for i in range(n): - entry = {k: v[i] for k, v in base_data.items()} - f.write(json.dumps(entry, ensure_ascii=False) + "\n") + f.write("\n".join(lines) + "\n") print(f"Dumped generations to {filename}") @@ -636,7 +675,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): import numpy as np # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) + samples = list(zip(inputs, outputs, scores, strict=True)) samples.sort(key=lambda x: x[0]) # Sort by input text # Use fixed random seed for deterministic shuffling @@ -649,7 +688,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): # Log to each configured logger self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - def _validate(self): + def _validate_training_data(self): data_source_lst = [] dataset_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) @@ -658,14 +697,16 @@ def _validate(self): sample_inputs = [] sample_outputs = [] sample_scores = [] + sample_generation_lengths = [] # Add list to collect token-based generation lengths - for test_data in self.val_dataloader: + print(f"Starting to validate training data") + for test_data in self.pre_train_dataloader: test_batch = DataProto.from_single_dict(test_data) # NOTE: print statements in this loop added by Reasoning360 are temporarily disabled # print(f"Shape of test_batch: {test_batch.batch['input_ids'].shape}") - # repeat test batch - test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True) + # repeat test batch, use a small number 4 for now, only remove the all correct case + test_batch = test_batch.repeat(repeat_times=4, interleave=True) # we only do validation on rule-based rm if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": @@ -724,9 +765,22 @@ def _validate(self): result = self.val_reward_fn(test_batch, return_dict=True) reward_tensor = result["reward_tensor"] scores = reward_tensor.sum(-1).cpu().tolist() - # print(f"Shape of reward_tensor: {reward_tensor.shape}") + print(f"Shape of reward_tensor: {reward_tensor.shape}") + print(f"scores (first 100): {scores[:100]}") sample_scores.extend(scores) + + # compute the pass rate for the batch + temp_df = pd.DataFrame({ + "prompt_id": test_batch.non_tensor_batch["prompt_id"], + "on_policy_pass_rate": scores + }) + print(f"temp_df: {temp_df}") + pass_rate_df = temp_df.groupby("prompt_id", as_index=False)["on_policy_pass_rate"].mean().set_index('prompt_id')[['on_policy_pass_rate']] + self.train_dataset.dataframe = self.train_dataset.dataframe.set_index('prompt_id') + self.train_dataset.dataframe.update(pass_rate_df) + self.train_dataset.dataframe = self.train_dataset.dataframe.reset_index() + reward_extra_infos_dict["reward"].extend(scores) if "reward_extra_info" in result: @@ -746,6 +800,13 @@ def _validate(self): data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + # Collect generation lengths + # Calculate token-based generation lengths using response masks + response_length = test_batch.batch["responses"].shape[-1] # Get response length dimension + response_mask = test_batch.batch["attention_mask"][:, -response_length:] # Get response portion of attention mask + generation_lengths = response_mask.sum(dim=-1).cpu().tolist() # Actual token lengths per response + sample_generation_lengths.extend(generation_lengths) + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) # dump generations @@ -774,13 +835,195 @@ def _validate(self): for metric_name, metric_val in metric2val.items(): # NOTE: Added by Reasoning360: Add std to the metric name. if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best", "std"]) and (f"@{n_max}" in metric_name): + metric_sec = "val-training-core" + else: + metric_sec = "val-training-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + print(f"Training data validation complete") + + # Calculate the average generation length for each data source + data_source_generation_lengths = {} + generation_lengths = sample_generation_lengths # Use already collected token-based generation lengths + + for i in range(len(generation_lengths)): + data_source = data_sources[i] + if data_source not in data_source_generation_lengths: + data_source_generation_lengths[data_source] = [] + data_source_generation_lengths[data_source].append(generation_lengths[i]) + + # Record the average generation length for each data source + for data_source, lengths in data_source_generation_lengths.items(): + metric_dict[f"val/avg_gen_length/{data_source}"] = np.mean(lengths) + + return + + + def _validate(self): + data_source_lst = [] + dataset_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + sample_turns = [] + sample_generation_lengths = [] # Add list to collect token-based generation lengths + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + # NOTE: print statements in this loop added by Reasoning360 are temporarily disabled + # print(f"Shape of test_batch: {test_batch.batch['input_ids'].shape}") + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + if "agent_name" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("agent_name") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + # print(f"Shape of reward_tensor: {reward_tensor.shape}") + + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + # NOTE: Added by Reasoning360: Collect dataset information. TODO: maybe replicated usage with the data_source_lst and can be removed? + datasets = [] + for i in range(reward_tensor.shape[0]): + dataset = "unknown" + if "extra_info" in test_batch.non_tensor_batch: + extra_info = test_batch.non_tensor_batch["extra_info"][i] + if isinstance(extra_info, dict) and "dataset" in extra_info: + dataset = extra_info["dataset"] + datasets.append(dataset) + dataset_lst.append(np.array(datasets)) + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + # Collect generation lengths + # Calculate token-based generation lengths using response masks + response_length = test_batch.batch["responses"].shape[-1] # Get response length dimension + response_mask = test_batch.batch["attention_mask"][:, -response_length:] # Get response portion of attention mask + generation_lengths = response_mask.sum(dim=-1).cpu().tolist() # Actual token lengths per response + sample_generation_lengths.extend(generation_lengths) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + # NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset + data_sources = np.concatenate(data_source_lst, axis=0) + + datasets = np.concatenate(dataset_lst, axis=0) # Concatenate datasets + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + # NOTE: Added by Reasoning360: Add std to the metric name. + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best", "std"]) + and (f"@{n_max}" in metric_name) + ): metric_sec = "val-core" else: metric_sec = "val-aux" pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" metric_dict[pfx] = metric_val - - # Calculate the mean reward for each data source and dataset + + # NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset data_source_dataset_reward = {} for i in range(len(sample_scores)): data_source = data_sources[i] @@ -790,9 +1033,29 @@ def _validate(self): data_source_dataset_reward[key] = [] data_source_dataset_reward[key].append(sample_scores[i]) + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + # Record the mean reward for each data source and dataset for (data_source, dataset), rewards in data_source_dataset_reward.items(): metric_dict[f"val/test_score/{data_source}/{dataset}"] = np.mean(rewards) + + # Calculate the average generation length for each data source + data_source_generation_lengths = {} + generation_lengths = sample_generation_lengths # Use already collected token-based generation lengths + + for i in range(len(generation_lengths)): + data_source = data_sources[i] + if data_source not in data_source_generation_lengths: + data_source_generation_lengths[data_source] = [] + data_source_generation_lengths[data_source].append(generation_lengths[i]) + + # Record the average generation length for each data source + for data_source, lengths in data_source_generation_lengths.items(): + metric_dict[f"val/avg_gen_length/{data_source}"] = np.mean(lengths) return metric_dict @@ -814,6 +1077,7 @@ def init_workers(self): cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, role="actor_rollout", + profile_option=self.config.trainer.npu_profile.options, ) self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls else: @@ -828,7 +1092,12 @@ def init_workers(self): # create reference policy if needed if self.use_reference_policy: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref") + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + profile_option=self.config.trainer.npu_profile.options, + ) self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls # create a reward model if reward_fn is None @@ -847,10 +1116,23 @@ def init_workers(self): wg_kwargs = {} # Setting up kwargs for RayWorkerGroup if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( + "worker_nsight_options must be set when profile_steps is set" + ) + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.trainer, "worker_nsight_options") + ) for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) @@ -873,42 +1155,69 @@ def init_workers(self): # create async rollout manager and request scheduler self.async_rollout_mode = False if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.experimental.agent_loop import AgentLoopManager + self.async_rollout_mode = True - self.async_rollout_manager = AsyncLLMServerManager( - config=self.config.actor_rollout_ref, + self.async_rollout_manager = AgentLoopManager( + config=self.config, worker_group=self.actor_rollout_wg, ) def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) print(f"local_global_step_folder: {local_global_step_folder}") actor_local_path = os.path.join(local_global_step_folder, "actor") - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) if remove_previous_ckpt_in_save: - print("Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead") - max_actor_ckpt_to_keep = self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - max_critic_ckpt_to_keep = self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + print( + "Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) - self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) if self.use_critic: critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") - self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep) + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) # save dataloader - BaseCheckpointManager.local_mkdir(local_global_step_folder) + local_mkdir_safe(local_global_step_folder) dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") dataloader_state_dict = self.train_dataloader.state_dict() torch.save(dataloader_state_dict, dataloader_local_path) # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt") + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.global_steps)) @@ -934,7 +1243,9 @@ def _load_checkpoint(self): else: if self.config.trainer.resume_mode == "resume_path": assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() @@ -949,10 +1260,14 @@ def _load_checkpoint(self): actor_path = os.path.join(global_step_folder, "actor") critic_path = os.path.join(global_step_folder, "critic") # load actor - self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load critic if self.use_critic: - self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) # load dataloader, # TODO: from remote not implemented yet @@ -963,17 +1278,43 @@ def _load_checkpoint(self): else: print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile() + if self.use_critic: + self.critic_wg.start_profile() + if self.use_rm: + self.rm_wg.start_profile() + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): """Reorder the data on single controller such that each dp rank gets similar total tokens""" attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=world_size, equal_size=True) + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) # reorder based on index. The data will be automatically equally partitioned by dispatch function global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) metrics.update(global_balance_stats) def fit(self): @@ -1015,11 +1356,21 @@ def fit(self): # we start from step 1 self.global_steps += 1 last_val_metrics = None + self.max_steps_duration = 0 for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: metrics = {} timing_raw = {} + + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + with marked_timer("start_profile", timing_raw): + self._start_profiling(do_profile) + batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation @@ -1031,38 +1382,42 @@ def fit(self): non_tensor_batch_keys_to_pop.append("raw_prompt") if "tools_kwargs" in batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + if "index" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("index") + if "agent_name" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("agent_name") + gen_batch = batch.pop( batch_keys=batch_keys_to_pop, non_tensor_batch_keys=non_tensor_batch_keys_to_pop, ) + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + is_last_step = self.global_steps >= self.total_training_steps - with _timer("step", timing_raw): + with marked_timer("step", timing_raw): # generate a batch - with _timer("gen", timing_raw): + with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: - self.async_rollout_manager.wake_up() gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - self.async_rollout_manager.sleep() - # NOTE: added by Reasoning360. TODO: fix this metric later. - # vllm_page_metrics = gen_batch_output.non_tensor_batch - # vllm_page_metrics = { - # k.removeprefix("metrics_") : v for k, v in vllm_page_metrics.items() - # if k.startswith("metrics_") - # } - # vllm_page_metrics = reduce_metrics(vllm_page_metrics) - # metrics.update(vllm_page_metrics) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - # NOTE: this is likely an abandoned branch (the async rollout not supported. not carefully maintained in Reasoning360 as well) - with _timer("gen_max", timing_raw): + with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) @@ -1073,12 +1428,15 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) - batch.batch["response_mask"] = compute_response_mask(batch) + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) # Balance the number of valid tokens across DP ranks. # NOTE: This usually changes the order of data in the `batch`, # which won't affect the advantage calculation (since it's based on uid), @@ -1090,7 +1448,7 @@ def fit(self): # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - with _timer("reward", timing_raw): + with marked_timer("reward", timing_raw, color="yellow"): # compute reward model score if self.use_rm: reward_tensor = self.rm_wg.compute_rm_score(batch) @@ -1102,13 +1460,13 @@ def fit(self): reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) # recompute old_log_probs - with _timer("old_log_prob", timing_raw): + with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) @@ -1139,7 +1497,7 @@ def fit(self): if self.use_reference_policy: # compute reference log_prob - with _timer("ref", timing_raw): + with marked_timer("ref", timing_raw, color="olive"): if not self.ref_in_actor: ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) else: @@ -1148,31 +1506,34 @@ def fit(self): # compute values if self.use_critic: - with _timer("values", timing_raw): + with marked_timer("values", timing_raw, color="cyan"): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - with _timer("adv", timing_raw): + with marked_timer("adv", timing_raw, color="brown"): # we combine with rule-based rm reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) batch.batch["token_level_scores"] = reward_tensor - print(f"{list(reward_extra_infos_dict.keys())=}") if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor batch = compute_advantage( batch, @@ -1181,15 +1542,12 @@ def fit(self): lam=self.config.algorithm.lam, num_repeat=self.config.actor_rollout_ref.rollout.n, norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable, - use_pf_ppo=self.config.algorithm.use_pf_ppo, - pf_ppo_reweight_method=self.config.algorithm.pf_ppo.reweight_method, - pf_ppo_weight_pow=self.config.algorithm.pf_ppo.weight_pow, + config=self.config.algorithm, ) # update critic if self.use_critic: - with _timer("update_critic", timing_raw): + with marked_timer("update_critic", timing_raw, color="pink"): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) @@ -1197,7 +1555,7 @@ def fit(self): # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor - with _timer("update_actor", timing_raw): + with marked_timer("update_actor", timing_raw, color="red"): batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) @@ -1206,8 +1564,7 @@ def fit(self): # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - with _timer("dump_rollout_generations", timing_raw): - print(batch.batch.keys()) + with marked_timer("dump_rollout_generations", timing_raw, color="green"): inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() @@ -1220,17 +1577,45 @@ def fit(self): ) # validate - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer("testing", timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): - with _timer("save_checkpoint", timing_raw): + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() + with marked_timer("stop_profile", timing_raw): + self._stop_profiling(do_profile) + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + # training metrics metrics.update( { @@ -1246,6 +1631,10 @@ def fit(self): n_gpus = self.resource_pool_manager.get_n_gpus() metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) @@ -1255,3 +1644,9 @@ def fit(self): pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index 7f6910ef3..143b631bc 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -22,7 +22,34 @@ from verl.utils.reward_score import default_compute_score +def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return raw_fn(*args, **merged_kwargs) + + def get_custom_reward_fn(config): + """Load and return a custom reward function from external file. + + Dynamically imports a reward function from a specified file path and wraps + it with additional keyword arguments from the configuration. + + Args: + config (dict): Configuration dictionary containing custom_reward_function + settings with 'path', 'name', and 'reward_kwargs' fields. + + Returns: + callable or None: Wrapped reward function with merged kwargs, or None + if no custom reward function is configured. + + Raises: + FileNotFoundError: If the specified reward function file doesn't exist. + RuntimeError: If there's an error loading the module from file. + AttributeError: If the specified function name isn't found in the module. + """ import importlib.util import sys @@ -51,46 +78,57 @@ def get_custom_reward_fn(config): reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) - def wrapped_fn(*args, **kwargs): - return raw_fn(*args, **kwargs, **reward_kwargs) - - return wrapped_fn + return partial(_call_with_kwargs, raw_fn, reward_kwargs) def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - elif reward_manager_name == "batch": - from verl.workers.reward_manager import BatchRewardManager + """ + Load and initialize a reward manager based on the configuration. - reward_manager_cls = BatchRewardManager - elif reward_manager_name == "dapo": - from verl.workers.reward_manager import DAPORewardManager + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. - reward_manager_cls = DAPORewardManager - else: - raise NotImplementedError + Returns: + An instance of the specified reward manager class. + """ + from verl.workers.reward_manager import get_reward_manager_cls + + # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: + # naive: NaiveRewardManager + # prime: PrimeRewardManager + # batch: BatchRewardManager + # dapo: DAPORewardManager + # Note(haibin.lin): For custom reward managers, please make sure they are imported and + # registered via `verl.workers.reward_manager.register` + # By default reward_manager is set to naive (NaiveRewardManager) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + # Try to get a custom reward function based on the configuration compute_score = get_custom_reward_fn(config) final_compute_score = compute_score if compute_score is None: sandbox_config = config.reward_model.get("sandbox_fusion") sandbox_url = sandbox_config.get("url") if sandbox_config else None + memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) if sandbox_url: sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) - final_compute_score = partial(default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore) + final_compute_score = partial( + default_compute_score, + sandbox_fusion_url=sandbox_url, + concurrent_semaphore=_concurrent_semaphore, + memory_limit_mb=memory_limit_mb, + ) else: final_compute_score = default_compute_score + # Instantiate and return the reward manager with the specified parameters return reward_manager_cls( tokenizer=tokenizer, num_examine=num_examine, @@ -112,7 +150,7 @@ def compute_reward(data: DataProto, reward_fn): try: reward_result = reward_fn(data, return_dict=True) reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result["reward_extra_info"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) except Exception as e: print(f"Error in reward_fn: {e}") reward_tensor = reward_fn(data) diff --git a/verl/trainer/runtime_env.yaml b/verl/trainer/runtime_env.yaml index 5aa693cd7..d29f2128b 100644 --- a/verl/trainer/runtime_env.yaml +++ b/verl/trainer/runtime_env.yaml @@ -2,5 +2,3 @@ working_dir: ./ excludes: ["/.git/"] env_vars: TORCH_NCCL_AVOID_RECORD_STREAMS: "1" - # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: - # VLLM_ATTENTION_BACKEND: "XFORMERS" \ No newline at end of file diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py index 85621a630..034584945 100644 --- a/verl/utils/__init__.py +++ b/verl/utils/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import tokenizer +from . import config, tokenizer +from .config import omega_conf_to_dataclass from .tokenizer import hf_processor, hf_tokenizer -__all__ = tokenizer.__all__ + ["hf_processor", "hf_tokenizer"] +__all__ = tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass"] diff --git a/verl/utils/activation_offload.py b/verl/utils/activation_offload.py index e07ee2626..73e2e83eb 100644 --- a/verl/utils/activation_offload.py +++ b/verl/utils/activation_offload.py @@ -25,6 +25,7 @@ import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from verl.utils.device import get_torch_device from verl.utils.fsdp_utils import FSDPModule as FSDP2 logger = logging.getLogger(__file__) @@ -94,11 +95,17 @@ def __init__(self) -> None: def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: """Tensor push.""" - raise NotImplementedError("`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your custom tensor_push.") + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_push." + ) def tensor_pop(self, tensor_tag: Any, **kwargs): """Tensor pop.""" - raise NotImplementedError("`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your custom tensor_pop.") + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_pop." + ) class GroupCommitFunction(torch.autograd.Function): @@ -250,16 +257,13 @@ def __init__( self.layer_window_map[i] += constant # allocate streams and events for synchronization - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: torch_stray_tensor = isinstance( tensor, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, ) need_offload = not torch_stray_tensor need_offload = need_offload and self.tensor_need_offloading_checker(tensor) @@ -295,7 +299,7 @@ def bulk_offload_group(self, group_to_offload): """Bulk offload group.""" offload_mapping = {} offload_size = 0 - with torch.cuda.stream(self.d2h_stream): + with get_torch_device().stream(self.d2h_stream): for tensor_tag, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_to_offload: @@ -318,15 +322,15 @@ def synchronize_on_group_commit_forward(self, current_group): # For the first group, kickstart the offload after we have # the first compute completion if current_group == 0: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.d2h_stream.wait_stream(get_torch_device().current_stream()) self.bulk_offload_group(current_group) # Window map data structure helps us synchronize based on number # of layers offloaded if self.layer_window_map[self.offloaded_group_count] == current_group: # Stream synchronization both ways - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) # Time to free the activation memory after usage for tensor_tag, _ in self.tensor_tag_to_buf.items(): @@ -352,7 +356,7 @@ def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group - with torch.cuda.stream(self.h2d_stream): + with get_torch_device().stream(self.h2d_stream): # move back tensors offload_mapping = self.group_offload_mapping.pop(group_to_reload) assert offload_mapping is not None @@ -376,8 +380,8 @@ def on_group_commit_backward(self): # Layer window data structure helps us to reload at right times if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: # Stream synchronization both ways - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) # Time to reload the next group self.bulk_reload_group(self.offloaded_group_count - 1) @@ -387,11 +391,13 @@ def on_group_commit_backward(self): # Last group computation needs to wait till all the reloads complete if self.current_group == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) + get_torch_device().current_stream().wait_stream(self.h2d_stream) self.offloaded_group_count = 0 -def get_activation_offload_context(num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True)): +def get_activation_offload_context( + num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True) +): cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, num_model_group=model_layers, @@ -442,7 +448,7 @@ def _unpack_kwargs(self, flat_args, kwarg_keys): if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] - kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True)) return args, kwargs def _ckpt_forward(self, forward_method, *args, **kwargs): @@ -517,7 +523,7 @@ def enable_activation_offloading(model, strategy, enable_ckpt=False): def get_layers(module): for name, child in module.named_children(): - if not isinstance(child, (FSDP, FSDP2)): + if not isinstance(child, FSDP | FSDP2): get_layers(child) else: wrapped_module = child @@ -536,7 +542,8 @@ def get_layers(module): tensor_filter = FSDPParameterFilter() context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) if enable_ckpt: - # The implementation of activation checkpointing in transformers library is incompatible with activation offloading, + # The implementation of activation checkpointing in transformers library is incompatible with + # activation offloading, # so it will be disabled, but this implementation supports another version of activation checkpointing, so that # these two features can be enabled at the same time. for module in model.modules(): diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index 076a319bb..9659b7b89 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -11,19 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import random import shutil import tempfile -from typing import Optional, Union +from filelock import FileLock import numpy as np import torch import torch.distributed -from filelock import FileLock +from omegaconf import DictConfig from transformers import PreTrainedTokenizer, ProcessorMixin -from verl.utils.device import is_cuda_available, is_npu_available +from verl.utils.device import get_device_name, get_torch_device class BaseCheckpointManager: @@ -46,11 +47,16 @@ def __init__( model, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, - processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: Optional[list] = None, + processing_class: PreTrainedTokenizer | ProcessorMixin = None, + checkpoint_config: DictConfig = None, ): - if checkpoint_contents is None: - checkpoint_contents = ["model", "optimizer", "extra"] + self.checkpoint_config = checkpoint_config + checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None + checkpoint_save_contents = checkpoint_config.get("save_contents", None) if checkpoint_config else None + if checkpoint_load_contents is None: + checkpoint_load_contents = ["model", "optimizer", "extra"] + if checkpoint_save_contents is None: + checkpoint_save_contents = ["model", "optimizer", "extra"] self.previous_global_step = None self.previous_saved_paths = [] @@ -58,15 +64,68 @@ def __init__( self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.processing_class = processing_class - self.checkpoint_contents = checkpoint_contents + self.checkpoint_load_contents = checkpoint_load_contents + self.checkpoint_save_contents = checkpoint_save_contents self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() + @property + def should_save_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved. + """ + return "model" in self.checkpoint_save_contents + + @property + def should_save_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved. + """ + return "optimizer" in self.checkpoint_save_contents + + @property + def should_save_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved. + """ + return "extra" in self.checkpoint_save_contents + + @property + def should_save_hf_model(self) -> bool: + """ + Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf + model and saved. + """ + return "hf_model" in self.checkpoint_save_contents + + @property + def should_load_model(self) -> bool: + """ + Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded. + """ + return "model" in self.checkpoint_load_contents + + @property + def should_load_optimizer(self) -> bool: + """ + Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded. + """ + return "optimizer" in self.checkpoint_load_contents + + @property + def should_load_extra(self) -> bool: + """ + Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded. + """ + return "extra" in self.checkpoint_load_contents + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): raise NotImplementedError - def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None): + def save_checkpoint( + self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None + ): raise NotImplementedError @staticmethod @@ -83,7 +142,7 @@ def remove_previous_save_local_path(self, path): if not os.path.exists(abs_path): continue shutil.rmtree(abs_path, ignore_errors=True) - + @staticmethod def local_mkdir(path): if not os.path.isabs(path): @@ -113,10 +172,8 @@ def get_rng_state(): "random": random.getstate(), } - if is_cuda_available: - rng_state["cuda"] = torch.cuda.get_rng_state() - elif is_npu_available: - rng_state["npu"] = torch.npu.get_rng_state() + if get_device_name() != "cpu": + rng_state[get_device_name()] = get_torch_device().get_rng_state() return rng_state @@ -126,10 +183,8 @@ def load_rng_state(rng_state): np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) - if is_cuda_available: - torch.cuda.set_rng_state(rng_state["cuda"]) - elif is_npu_available: - torch.npu.set_rng_state(rng_state["npu"]) + if get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_state[get_device_name()]) def find_latest_ckpt_path(path, directory_format="global_step_{}"): @@ -150,7 +205,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"): tracker_file = get_checkpoint_tracker_filename(path) if not os.path.exists(tracker_file): - print("Checkpoint tracker file does not exist: %s", tracker_file) + print(f"Checkpoint tracker file does not exist: {tracker_file}") return None with open(tracker_file, "rb") as f: @@ -169,3 +224,37 @@ def get_checkpoint_tracker_filename(root_path: str): Tracker file rescords the latest chckpoint during training to restart from. """ return os.path.join(root_path, "latest_checkpointed_iteration.txt") + + +def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool: + """ + Determine if checkpoint should be saved based on capacity esi expiration. + + Args: + max_steps_duration: Max estimated time (seconds) required to complete one training step + save_ckpt_duration: Estimated time (seconds) required to save checkpoint (default: 60) + redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0) + """ + exp_ts_mlp = os.getenv("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # vemlp + exp_ts_aws = os.getenv("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # aws + if exp_ts_mlp: + try: + import time + + remaining = float(exp_ts_mlp) - time.time() + except ValueError: + return False + return ( + remaining > 0 + and max_steps_duration > 0 + and remaining <= save_ckpt_duration + max_steps_duration + redundant_time + ) + elif exp_ts_aws: + from datetime import datetime, timedelta + + expiration_time = datetime.fromtimestamp(int(exp_ts_aws)) + time_difference = expiration_time - datetime.now() + threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60 + return time_difference < timedelta(minutes=threshold_minutes) + else: + return False diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 99c724e7a..e81aebbd0 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -12,23 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import logging import os import warnings -from typing import Optional, Union +from dataclasses import asdict, dataclass +from typing import Optional import torch import torch.distributed from accelerate import init_empty_weights -from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType +from omegaconf import DictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin from verl.utils.device import is_cuda_available -from verl.utils.fs import copy_to_local, is_non_local -from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx +from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe +from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx +from verl.utils.logger import log_with_rank from .checkpoint_manager import BaseCheckpointManager +# Setup logging +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +@dataclass +class FSDPConfig: + """Configuration for FSDP checkpointing. + + Args: + FSDP_version (int): Version of FSDP being used. + world_size (int): Number of processes in the distributed training setup. + """ + + FSDP_version: int + world_size: int + class FSDPCheckpointManager(BaseCheckpointManager): """ @@ -44,33 +66,33 @@ class FSDPCheckpointManager(BaseCheckpointManager): lr_scheduler (LRScheduler): Learning-rate scheduler. processing_class (PreTrainedTokenizer or ProcessorMixin, optional): Pre-/post-processing artifact handler. - checkpoint_contents (list[str], optional): - Components to include; must contain 'model', 'optimizer', 'extra'. + checkpoint_contents DictConfig: Configuration for checkpoint contents. + - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. + - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. """ def __init__( self, model: FSDP, - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler, - processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, - checkpoint_contents: Optional[list] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + processing_class: PreTrainedTokenizer | ProcessorMixin = None, + checkpoint_config: DictConfig = None, **kwargs, ): - if checkpoint_contents is None: - checkpoint_contents = ["model", "optimizer", "extra"] if processing_class is None: assert "tokenizer" in kwargs, "tokenizer or processor must be provided" - warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2) + warnings.warn( + "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2 + ) processing_class = kwargs.pop("tokenizer") - assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}" super().__init__( model, optimizer, lr_scheduler=lr_scheduler, processing_class=processing_class, - checkpoint_contents=checkpoint_contents, + checkpoint_config=checkpoint_config, ) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): @@ -89,42 +111,71 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte if local_path is None: return + # check if the checkpoint_load_contents is valid + if self.should_load_model: + assert self.model is not None, "model must be provided when checkpoint_contents.load includes ['model']" + if self.should_load_optimizer: + assert self.optimizer is not None, ( + "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" + ) + # every rank download its own checkpoint - remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - print(f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}") - local_model_path = copy_to_local(remote_model_path) - local_optim_path = copy_to_local(remote_optim_path) - local_extra_state_path = copy_to_local(remote_extra_state_path) - - model_state_dict = torch.load(local_model_path, weights_only=False) - optimizer_state_dict = torch.load(local_optim_path, weights_only=False) - extra_state_dict = torch.load(local_extra_state_path, weights_only=False) - - if del_local_after_load: + state_dict_cfg = ( + ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + if self.should_load_model + else None + ) + optim_cfg = ( + ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + if self.should_load_optimizer + else None + ) + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + if self.should_load_model: + remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + local_model_path = copy_to_local(remote_model_path) + model_state_dict = torch.load(local_model_path, weights_only=False) + self.model.load_state_dict(model_state_dict) + log_with_rank(f"Loaded model from {remote_model_path}", rank=self.rank, logger=logger) + + if self.should_load_optimizer: + remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + local_optim_path = copy_to_local(remote_optim_path) + optimizer_state_dict = torch.load(local_optim_path, weights_only=False) + self.optimizer.load_state_dict(optimizer_state_dict) + log_with_rank(f"Loaded optimizer from {remote_optim_path}", rank=self.rank, logger=logger) + + if self.should_load_extra: + remote_extra_state_path = os.path.join( + local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + local_extra_state_path = copy_to_local(remote_extra_state_path) + extra_state_dict = torch.load(local_extra_state_path, weights_only=False) + # recover random state + if "rng" in extra_state_dict: + # 'rng' may not exist for backward compatibility + self.load_rng_state(extra_state_dict["rng"]) + log_with_rank(f"Loaded rng from {remote_extra_state_path}", rank=self.rank, logger=logger) + + lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] + if lr_scheduler_state_dict is not None and self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + log_with_rank(f"Loaded lr_scheduler from {remote_extra_state_path}", rank=self.rank, logger=logger) + + if self.rank == 0 and del_local_after_load: try: os.remove(local_model_path) if is_non_local(local_model_path) else None os.remove(local_optim_path) if is_non_local(local_optim_path) else None os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None except Exception as e: - print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored") + log_with_rank( + f"remove local resume ckpt file after loading failed, exception {e} will be ignored", + rank=self.rank, + logger=logger, + ) - lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] - - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - self.model.load_state_dict(model_state_dict) - if self.optimizer is not None: - self.optimizer.load_state_dict(optimizer_state_dict) - # recover random state - if "rng" in extra_state_dict: - # 'rng' may not exist for backward compatibility - self.load_rng_state(extra_state_dict["rng"]) - - if self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + # wait for everyone to load checkpoints + torch.distributed.barrier() def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): """ @@ -150,16 +201,29 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i # record the previous global step self.previous_global_step = global_step - # remove previous local_path - if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep: + # remove previous local_path, only rank 0 should do this + if ( + self.rank == 0 + and max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) self.previous_saved_paths = self.previous_saved_paths[keep_start:] - if self.rank == 0: # added by Reasoning360: file system got problem on rank0 when co-current making dirs, so we make dirs on rank0 only - local_path = self.local_mkdir(local_path) + # if self.rank == 0: # added by Reasoning360: file system got problem on rank0 when co-current making dirs, so we make dirs on rank0 only + local_path = local_mkdir_safe(local_path) torch.distributed.barrier() - local_path = self.local_mkdir(local_path) # hack fix: to get the local path for non-rank0 + + # check if the checkpoint_save_contents is valid + if self.should_save_model: + assert self.model is not None, "model must be provided when checkpoint_contents.save includes ['model']" + if self.should_save_optimizer: + assert self.optimizer is not None, ( + "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" + ) # every rank will save its own model and optim shard state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) @@ -167,59 +231,82 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i with warnings.catch_warnings(): warnings.simplefilter("ignore") with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - model_state_dict = self.model.state_dict() - optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None - lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None - - extra_state_dict = { - "lr_scheduler": lr_scheduler_state_dict, - "rng": self.get_rng_state(), - } model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}") - print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}") - print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}") - torch.save(model_state_dict, model_path) - torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None - torch.save(extra_state_dict, extra_path) + if self.should_save_model: + model_state_dict = self.model.state_dict() + torch.save(model_state_dict, model_path) + log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) + + if self.should_save_optimizer: + optimizer_state_dict = self.optimizer.state_dict() + torch.save(optimizer_state_dict, optim_path) + log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) + + if self.should_save_extra: + lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + torch.save(extra_state_dict, extra_path) + log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) if self.rank == 0: + # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether + # huggingface model is requested to be saved or not. + if fsdp_version(self.model) == 1: unwrap_model = self.model._fsdp_wrapped_module else: unwrap_model = self.model + hf_config_tokenizer_path = os.path.join(local_path, "huggingface") + local_mkdir_safe(hf_config_tokenizer_path) model_config = unwrap_model.config + generation_config = None if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path: - # Some model's name_or_path is empty if not initialized from pretrained, - # in this cases, we don't save generation config. - generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) - generation_config.save_pretrained(local_path) - else: - generation_config = None - - model_config.save_pretrained(local_path) - self.processing_class.save_pretrained(local_path) + try: + # Some model's name_or_path is empty if not initialized from pretrained, + # in this cases, we don't save generation config. + generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + generation_config.save_pretrained(hf_config_tokenizer_path) + except Exception: + # if the generation config isn't available, we don't save it + pass + + model_config.save_pretrained(hf_config_tokenizer_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + log_with_rank( + f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + # Also save runtime FSDP config + fsdp_config_path = os.path.join(local_path, "fsdp_config.json") + fsdp_config = FSDPConfig( + FSDP_version=fsdp_version(self.model), + world_size=self.world_size, + ) + with open(fsdp_config_path, "w") as f: + json.dump(asdict(fsdp_config), f, indent=4) # wait for everyone to dump to local torch.distributed.barrier() - if "hf_model" in self.checkpoint_contents: - hf_local_path = os.path.join(local_path, "huggingface") - if self.rank == 0: - os.makedirs(hf_local_path, exist_ok=True) - torch.distributed.barrier() - + if self.should_save_hf_model: # Only rank 0 will save hf model and, # offload to cpu to save LLMs which may be too large to fit in one GPU - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): - state_dict = self.model.state_dict() + state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) if self.rank == 0: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + if "ForTokenClassification" in model_config.architectures[0]: from transformers import AutoModelForTokenClassification @@ -243,10 +330,18 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if generation_config is not None: save_model.generation_config = generation_config else: - print(f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model.") + print( + f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found " + f"in, using a generation config created from the model config when saving hf_model." + ) save_model.save_pretrained(hf_local_path, state_dict=state_dict) - self.processing_class.save_pretrained(hf_local_path) + log_with_rank( + f"Saved hf_model to {os.path.abspath(hf_local_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) del state_dict del save_model diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index 8eb6a2524..b9fcc551b 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -12,72 +12,124 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import logging import os import random -from typing import Optional +from collections.abc import Callable +from dataclasses import asdict import numpy as np import torch import torch.distributed from megatron.core import mpu, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject +from megatron.core.transformer.enums import AttnBackend from transformers import GenerationConfig from verl.models.weight_loader_registry import get_weight_saver -from verl.utils.fs import is_non_local +from verl.utils.device import get_device_name, get_torch_device +from verl.utils.fs import is_non_local, local_mkdir_safe +from verl.utils.logger import log_with_rank +from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing from verl.utils.megatron_utils import ( - get_hf_config_and_tokenizer_checkpoint_path, + get_dist_checkpoint_path, get_hf_model_checkpoint_path, - get_model_checkpoint_path, - get_optimizer_checkpoint_path, - get_rng_states_checkpoint_path, + get_transformer_config_checkpoint_path, ) from .checkpoint_manager import BaseCheckpointManager +# Setup logging +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + class MegatronCheckpointManager(BaseCheckpointManager): """ - A checkpoint manager that saves and loads - - model - - optimizer - - lr_scheduler - - extra_states - in a SPMD way. - - We save - - sharded model states and optimizer states - - full lr_scheduler states - - huggingface tokenizer/processor and config for ckpt merge + Checkpoint manager for Megatron-LM distributed training. + + This class manages the saving and loading of model checkpoints in a Megatron-LM + distributed training environment. It handles various aspects of checkpointing + including model states, optimizer states, learning rate schedulers, and random + number generator states, ensuring compatibility with HuggingFace formats. + + Key features: + - Distributed checkpoint saving and loading using Megatron's dist_checkpointing + - Support for tensor parallel, pipeline parallel, and data parallel configurations + - Automatic handling of model state dictionaries across multiple pipeline stages + - Integration with HuggingFace model configurations and tokenizers + - Random number generator state management for reproducibility + - Support for both synchronous and asynchronous checkpoint operations + + The manager automatically handles: + - Directory structure creation based on global steps and process ranks + - Model configuration and tokenizer saving in HuggingFace format + - Optimizer and scheduler state persistence + - CUDA RNG state management for deterministic training + - Checkpoint cleanup and retention policies + + Args: + model: The Megatron model instance to checkpoint + optimizer: The optimizer instance (optional) + lr_scheduler: The learning rate scheduler instance (optional) + + Attributes: + model: Reference to the Megatron model being checkpointed + optimizer: Reference to the optimizer (if provided) + lr_scheduler: Reference to the learning rate scheduler (if provided) + rank: Current process rank in the distributed setup + + Example: + ```python + checkpoint_manager = MegatronCheckpointManager( + model=megatron_model, + optimizer=optimizer, + lr_scheduler=scheduler + ) + + checkpoint_manager.save_checkpoint( + local_path="checkpoints/step_1000", + global_step=1000 + ) + + checkpoint_manager.load_checkpoint( + local_path="checkpoints/step_1000" + ) + ``` """ def __init__( self, config, + checkpoint_config, model_config, + transformer_config, role, model: torch.nn.ModuleList, arch: str, hf_config, param_dtype: torch.dtype, share_embeddings_and_output_weights: bool, - tokenizer, + processing_class, optimizer, + optimizer_scheduler, use_distributed_optimizer: bool, - checkpoint_contents: Optional[list] = None, + use_checkpoint_opt_param_scheduler: bool = False, + use_dist_checkpointing: bool = True, + bridge=None, **kwargs, ): - if checkpoint_contents is None: - checkpoint_contents = ["model", "optimizer", "extra"] super().__init__( model, optimizer=optimizer, - lr_scheduler=None, - processing_class=tokenizer, - checkpoint_contents=checkpoint_contents, + lr_scheduler=optimizer_scheduler, + processing_class=processing_class, + checkpoint_config=checkpoint_config, ) self.arch = arch self.config = config + self.transformer_config = transformer_config self.role = role self.is_value_model = False if self.role in ["reward", "critic"]: @@ -88,21 +140,26 @@ def __init__( self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.model_path = self.config.model.path self.use_distributed_optimizer = use_distributed_optimizer - + self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler + self.bridge = bridge self.rank = torch.distributed.get_rank() + self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model + self.use_hf_checkpoint = not self.use_dist_checkpointing self.weight_saver = get_weight_saver(self.arch) - def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False): + def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False): """collect rng state across data parallel ranks""" rng_state = { "random_rng_state": random.getstate(), "np_rng_state": np.random.get_state(), "torch_rng_state": torch.get_rng_state(), - "cuda_rng_state": torch.cuda.get_rng_state(), "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), } + if get_device_name() != "cpu": + rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state() + rng_state_list = None if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] @@ -115,13 +172,11 @@ def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: pp_size = mpu.get_pipeline_model_parallel_world_size() tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - cp_size = mpu.get_context_parallel_world_size() rng_state_list = ShardedObject( "rng_state", rng_state_list, - (pp_size, tp_size, cp_size), - (pp_rank, tp_rank, cp_rank), + (pp_size, tp_size), + (pp_rank, tp_rank), replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), ) @@ -167,123 +222,238 @@ def get_checkpoint_name( if expert_parallel: common_path = common_path + f"_{expert_rank:03d}" - # replace os.makedirs by local_mkdir - self.local_mkdir(common_path) + # NOTE: Added by Reasoning360: replace os.makedirs by local_mkdir + local_mkdir_safe(common_path) if return_base_dir: return common_path return os.path.join(common_path, basename) - def load_optimizer(self, ckpt_path): - # TODO: Check Optimizer format and distributed optimizer - optimizer_path = get_optimizer_checkpoint_path(ckpt_path) - print(f"Loading optimizer from {optimizer_path}") - self.optimizer.load_parameter_state(optimizer_path) + def generate_state_dict(self): + # For save dist checkpointing + state_dict = {} + + # All ranks Save Model to reduce memory pressure + if self.should_save_model or self.should_load_model: + # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure + for vpp_rank, model in enumerate(self.model): + if len(self.model) > 1: + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + key = f"model{vpp_rank}" if len(self.model) > 1 else "model" + else: + key = "model" + if hasattr(model, "module"): + model = model.module + state_dict[key] = model.sharded_state_dict() + + # Optimizer State Dict + if self.should_save_optimizer or self.should_load_optimizer: + torch.distributed.barrier() + optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict) + state_dict["optimizer"] = optimizer_sharded_states - def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False): - rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False) - print(f"Loading rng states from {rng_state_path}") - rng_state = torch.load(rng_state_path, weights_only=False) + if self.lr_scheduler is not None: + lr_state_dict = self.lr_scheduler.state_dict() + state_dict["lr_scheduler"] = lr_state_dict + + # RNG States State Dict + if self.should_save_extra or self.should_load_extra: + torch.distributed.barrier() + rng_state = self.get_rng_state() + state_dict["rng_state"] = rng_state + + return state_dict + + def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True): # access rng_state for data parallel rank - if not use_dist_ckpt: - rng_state = rng_state[mpu.get_data_parallel_rank()] if data_parallel_random_init else rng_state[0] - random.setstate(rng_state["random_rng_state"]) - np.random.set_state(rng_state["np_rng_state"]) - torch.set_rng_state(rng_state["torch_rng_state"]) - torch.cuda.set_rng_state(rng_state["cuda_rng_state"]) + if data_parallel_random_init: + rng_states = rng_states[mpu.get_data_parallel_rank()] + else: + rng_states = rng_states[0] + random.setstate(rng_states["random_rng_state"]) + np.random.set_state(rng_states["np_rng_state"]) + torch.set_rng_state(rng_states["torch_rng_state"]) + + if get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_states[f"{get_device_name()}_rng_state"]) + # Check for empty states array - if not rng_state["rng_tracker_states"]: + if not rng_states["rng_tracker_states"]: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"]) + tensor_parallel.get_cuda_rng_tracker().set_states(rng_states["rng_tracker_states"]) def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): - if local_path is None: - return - - if "model" in self.checkpoint_contents: - model_path = get_model_checkpoint_path(local_path) - ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False) - state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False) - assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}" - for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)): - model.load_state_dict(state_dict) - print(f"Loaded sharded model checkpoint from {model_path}") - - if "optimizer" in self.checkpoint_contents: - self.load_optimizer(local_path) + if local_path is not None: + assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist." + + dist_checkpoint_path = get_dist_checkpoint_path(local_path) + + # Get State Dict for loading + sharded_state_dict = self.generate_state_dict() + log_with_rank(f"Generated state dict for saving: {sharded_state_dict.keys()}", rank=self.rank, logger=logger) + for vpp_rank, model in enumerate(self.model): + if len(self.model) > 1: + model_i_keys = sharded_state_dict[f"model{vpp_rank}"].keys() + log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + else: + log_with_rank( + f"Generated state dict for saving: {sharded_state_dict['model'].keys()}", + rank=self.rank, + logger=logger, + ) + + # Load Dist Checkpointing + state_dict = load_dist_checkpointing( + sharded_state_dict=sharded_state_dict, + ckpt_dir=dist_checkpoint_path, + ) - if "extra" in self.checkpoint_contents: - self.load_rng_states(local_path) + if self.should_load_model and self.use_dist_checkpointing: + assert "model" in state_dict or any( + f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model)) + ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + for vpp_rank, model in enumerate(self.model): + if len(self.model) == 1: + model_state_dict = state_dict["model"] + else: + assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict" + model_state_dict = state_dict[f"model{vpp_rank}"] + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + self.model[vpp_rank].load_state_dict(model_state_dict) + log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger) + elif self.should_load_model and self.use_hf_checkpoint: + hf_model_path = get_hf_model_checkpoint_path(local_path) + self.bridge.load_weights(self.model, hf_model_path) + log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger) + + if self.should_load_optimizer: + assert "optimizer" in state_dict, ( + f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + ) + optimizer_state_dict = state_dict["optimizer"] + self.optimizer.load_state_dict(optimizer_state_dict) + log_with_rank(f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger) + if self.use_checkpoint_opt_param_scheduler: + assert "lr_scheduler" in state_dict, ( + f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file " + f"{local_path}." + ) + lr_scheduler_state_dict = state_dict["lr_scheduler"] + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + log_with_rank(f"Loaded LR scheduler checkpoint from {local_path}", rank=self.rank, logger=logger) + + if self.should_load_extra: + assert "rng_state" in state_dict, ( + f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." + ) + rng_state = state_dict["rng_state"] + self.load_rng_states(rng_state) + log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger) if del_local_after_load: try: os.remove(local_path) if is_non_local(local_path) else None except Exception as e: - print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored") + log_with_rank( + f"remove local resume ckpt file after loading failed, exception {e} will be ignored", + rank=self.rank, + logger=logger, + ) def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): # record the previous global step self.previous_global_step = global_step # remove previous local_path - if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep: + if ( + max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) self.previous_saved_paths = self.previous_saved_paths[keep_start:] - if self.rank == 0: - # NOTE: bug fix by Reasoning360: avoid multiple nodes creating the same directory - local_path = self.local_mkdir(local_path) - if "model" in self.checkpoint_contents: - model_ckpt_path = get_model_checkpoint_path(local_path) - model_ckpt_path = self.local_mkdir(model_ckpt_path) - if "optimizer" in self.checkpoint_contents: - optimizer_path = os.path.join(local_path, "optim") - optimizer_path = self.local_mkdir(optimizer_path) - if "extra" in self.checkpoint_contents: - rng_state_path = os.path.join(local_path, "rng_states") - rng_state_path = self.local_mkdir(rng_state_path) - - torch.distributed.barrier() - local_path = self.local_mkdir(local_path) - - # Save Model - # NOTE: bug fix by Reasoning360: only save one copy for the CP group - if ("model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0 and - (mpu.get_context_parallel_world_size() <= 1 or mpu.get_context_parallel_rank() == 0) - ): - state_dicts = [] + local_path = local_mkdir_safe(local_path) + dist_checkpoint_path = get_dist_checkpoint_path(local_path) + if self.use_dist_checkpointing: + # Generate state dict for saving + state_dict = self.generate_state_dict() + log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) for vpp_rank, model in enumerate(self.model): - state_dict = model.state_dict() - state_dicts.append(state_dict) - - print(f"Saving sharded model checkpoint to {local_path}") - model_ckpt_path = get_model_checkpoint_path(local_path) - hf_config_and_tokenizer_path = get_hf_config_and_tokenizer_checkpoint_path(local_path) - ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False) - torch.save(state_dicts, os.path.join(ckpt_name)) + if len(self.model) > 1: + model_i_keys = state_dict[f"model{vpp_rank}"].keys() + log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + else: + log_with_rank( + f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger + ) + # Start Async save if enabled + async_save_request = save_dist_checkpointing( + sharded_state_dict=state_dict, + ckpt_path=dist_checkpoint_path, + async_save=self.checkpoint_config.async_save, + ) - print(f"Saved checkpoint to {model_ckpt_path}") + # Synchronize all async save requests + if not self.checkpoint_config.async_save: + assert async_save_request is None, "Async save request should be None when not using async save." + torch.distributed.barrier() + else: + assert self.use_hf_checkpoint, "use_hf_checkpoint should be True when not using dist checkpointing" + log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger) + hf_ckpt_path = get_hf_model_checkpoint_path(local_path) + self.bridge.save_weights(self.model, hf_ckpt_path) + log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger) + + if self.should_save_model: + # Only rank 0 saves the hf config and tokenizer to huggingface path + # No matter whether we save hf model or not if self.rank == 0: - self.processing_class.save_pretrained(hf_config_and_tokenizer_path) - self.hf_config.save_pretrained(hf_config_and_tokenizer_path) + # Save tokenizer + hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + # Save huggingface config + self.hf_config.save_pretrained(hf_config_tokenizer_path) if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: try: generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path) - generation_config.save_pretrained(hf_config_and_tokenizer_path) + generation_config.save_pretrained(hf_config_tokenizer_path) except Exception: # if the generation config isn't available, we don't save it pass - if hdfs_path is not None: - print(f"Uploading checkpoint to {hdfs_path}") - from verl.utils import hdfs_io - - hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) - hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) - - if "hf_model" in self.checkpoint_contents: + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + if self.should_save_extra: + if self.rank == 0: + # Save transformer config + print(self.transformer_config) + transformer_config_dict = asdict(self.transformer_config) + to_convert_types = {torch.dtype: str, AttnBackend: str} + ignore_types = [Callable] + pop_keys = [] + for key, value in transformer_config_dict.items(): + if type(value) in to_convert_types: + transformer_config_dict[key] = to_convert_types[type(value)](value) + if type(value) in ignore_types: + pop_keys.append(key) + if callable(value): + pop_keys.append(key) + for key in pop_keys: + transformer_config_dict.pop(key) + transformer_config_path = get_transformer_config_checkpoint_path(local_path) + with open(transformer_config_path, "w") as f: + json.dump(transformer_config_dict, f, indent=2) + + if self.should_save_hf_model: # wait for everyone to dump to local state_dict = self.weight_saver( self.model, @@ -295,9 +465,6 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i torch.distributed.barrier() if self.rank == 0: - print(f"self.param_dtype: {self.param_dtype}") - for key in state_dict.keys(): - print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}") hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) import warnings @@ -308,43 +475,52 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if "mistral7b-rm" in self.config.model.path: from transformers import MistralForSequenceClassification - model = MistralForSequenceClassification.from_pretrained(self.config.model.path) # use score head instead of lm_head + model = MistralForSequenceClassification.from_pretrained( + self.config.model.path + ) # use score head instead of lm_head state_dict["score.weight"] = state_dict["score.weight"] else: from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - self.processing_class.save_pretrained(hf_model_ckpt_path) + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) if hdfs_path is not None: - print(f"Uploading checkpoint to {hdfs_path}") + log_with_rank( + f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + ) from verl.utils import hdfs_io hdfs_io.makedirs(hdfs_path, exist_ok=True) hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) - - # Save Optimizer - if "optimizer" in self.checkpoint_contents: - torch.distributed.barrier() - - optimizer_path = get_optimizer_checkpoint_path(local_path) - self.optimizer.save_parameter_state(optimizer_path) + log_with_rank( + f"HDFS checkpoint uploaded to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + ) + + def finalize_save_fn(): + # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided + log_with_rank( + f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, logger=logger + ) if self.rank == 0: - print(f"saving optimizer state to {optimizer_path}") + if hdfs_path is not None: + log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger) + from verl.utils import hdfs_io - # Save RNG States - if "extra" in self.checkpoint_contents: - torch.distributed.barrier() - # NOTE: bug saving by Reasoning360: here we implicitly create a local path and multiple nodes - # may try to access it together. - if self.rank == 0: - rng_state_parent_path = self.local_mkdir(os.path.join(local_path, "rng_states")) - torch.distributed.barrier() + hdfs_io.makedirs(hdfs_path, exist_ok=True) + hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True) + hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) - rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False) - rng_state = self.get_rng_state() - torch.save(rng_state, rng_state_path) - print(f"Rank {self.rank} saving rng states to {rng_state_path}") + if self.checkpoint_config.async_save: + assert async_save_request is not None, "Async save request should not be None when using async save." + async_save_request.add_finalize_fn(finalize_save_fn) + else: + finalize_save_fn() self.previous_saved_paths.append(local_path) diff --git a/verl/utils/config.py b/verl/utils/config.py index 5c9298c42..f1c301f24 100644 --- a/verl/utils/config.py +++ b/verl/utils/config.py @@ -12,12 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from dataclasses import is_dataclass +from typing import Any, Optional -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig, OmegaConf +__all__ = ["omega_conf_to_dataclass"] -def update_dict_with_config(dictionary: Dict, config: DictConfig): + +def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any: + """ + Convert an OmegaConf DictConfig to a dataclass. + + Args: + config: The OmegaConf DictConfig or dict to convert. + dataclass_type: The dataclass type to convert to. When dataclass_type is None, + the DictConfig must contain _target_ to be instantiated via hydra.instantiate API. + + Returns: + The dataclass instance. + """ + # Got an empty config + if not config: + return dataclass_type if dataclass_type is None else dataclass_type() + # Got an object + if not isinstance(config, DictConfig | ListConfig | dict | list): + return config + + if dataclass_type is None: + assert "_target_" in config, ( + "When dataclass_type is not provided, config must contain _target_." + "See trainer/config/ppo_trainer.yaml algorithm section for an example." + ) + from hydra.utils import instantiate + + return instantiate(config, _convert_="partial") + + if not is_dataclass(dataclass_type): + raise ValueError(f"{dataclass_type} must be a dataclass") + cfg = OmegaConf.create(config) # in case it's a dict + cfg_from_dataclass = OmegaConf.structured(dataclass_type) + # let cfg override the existing vals in `cfg_from_dataclass` + cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg) + # now convert to `dataclass_type` + config_object = OmegaConf.to_object(cfg_merged) + return config_object + + +def update_dict_with_config(dictionary: dict, config: DictConfig): for key in dictionary: if hasattr(config, key): dictionary[key] = getattr(config, key) diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index d6ea9af3a..e3eed0fd6 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -1,4 +1,5 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 ModelBest Inc. and/or its affiliates # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +16,11 @@ Multi-turn SFT dataset that supports training on conversation data with multiple turns """ -from typing import List, Union +import json +import logging +from typing import Any, Optional +import numpy as np import pandas as pd import torch from torch.utils.data import Dataset @@ -26,12 +30,25 @@ from verl.utils.fs import copy_local_path_from_hdfs +def convert_nested_value_to_list_recursive(data_item): + if isinstance(data_item, dict): + return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} + elif isinstance(data_item, list): + return [convert_nested_value_to_list_recursive(elem) for elem in data_item] + elif isinstance(data_item, np.ndarray): + # Convert to list, then recursively process the elements of the new list + return convert_nested_value_to_list_recursive(data_item.tolist()) + else: + # Base case: item is already a primitive type (int, str, float, bool, etc.) + return data_item + + class MultiTurnSFTDataset(Dataset): """ Dataset for multi-turn conversations where each assistant response should be trained """ - def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None): + def __init__(self, parquet_files: str | list[str], tokenizer, config=None): # Set defaults and extract parameters from config if provided config = config or {} self.truncation = config.get("truncation", "error") @@ -39,10 +56,11 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None) # Get messages_key from the new multiturn config structure multiturn_config = config.get("multiturn", {}) self.messages_key = multiturn_config.get("messages_key", "messages") - + self.tools_key = multiturn_config.get("tools_key", "tools") + self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking") assert self.truncation in ["error", "left", "right"] - if not isinstance(parquet_files, List): + if not isinstance(parquet_files, list): parquet_files = [parquet_files] self.parquet_files = parquet_files @@ -62,7 +80,7 @@ def series_to_item(ls): import numpy import pandas - while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: ls = ls[0] return ls @@ -75,46 +93,222 @@ def series_to_item(ls): # Extract messages list from dataframe self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + # Extract tools list from dataframe + if self.tools_key in self.dataframe.columns: + self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist() + else: + self.tools = None + # Extract enable_thinking list from dataframe + if self.enable_thinking_key in self.dataframe.columns: + self.enable_thinking = self.dataframe[self.enable_thinking_key].tolist() + else: + self.enable_thinking = None + def __len__(self): return len(self.messages) + def _process_message_tokens( + self, + messages: list[dict[str, Any]], + start_idx: int, + end_idx: int, + is_assistant: bool = False, + enable_thinking: Optional[bool] = None, + tools: Optional[list[dict[str, Any]]] = None, + ) -> tuple[list[int], list[int], list[int]]: + """ + Process tokens for a single message or a group of messages. + + Args: + messages: List of message dictionaries + start_idx: Start index in messages list + end_idx: End index in messages list + is_assistant: Whether this is an assistant message + enable_thinking: Whether to enable thinking mode + + Returns: + Tuple of (tokens, loss_mask, attention_mask) + """ + if start_idx > 0: + prev_applied_text = self.tokenizer.apply_chat_template( + messages[:start_idx], + tokenize=False, + add_generation_prompt=False, + enable_thinking=enable_thinking, + tools=tools, + ) + if is_assistant: + prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template( + messages[:start_idx], + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + tools=tools, + ) + + else: + prev_applied_text = "" + + cur_applied_text = self.tokenizer.apply_chat_template( + messages[:end_idx], + tokenize=False, + add_generation_prompt=False, + enable_thinking=enable_thinking, + tools=tools, + ) + # Get tokens for the current message only + if is_assistant: + generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :] + generation_prompt_tokens = self.tokenizer.encode( + generation_prompt_text, + add_special_tokens=False, + ) + _message_tokens = self.tokenizer.encode( + cur_applied_text[len(prev_applied_text_w_generation_prompt) :], + add_special_tokens=False, + ) + message_tokens = generation_prompt_tokens + _message_tokens + loss_mask = [0] * (len(generation_prompt_tokens)) + [1] * ( + len(message_tokens) - len(generation_prompt_tokens) + ) + else: + message_tokens = self.tokenizer.encode( + cur_applied_text[len(prev_applied_text) :], + add_special_tokens=False, + ) + loss_mask = [0] * len(message_tokens) + + attention_mask = [1] * len(message_tokens) + + return message_tokens, loss_mask, attention_mask + + def _validate_and_convert_tokens( + self, + full_tokens: torch.Tensor, + concat_tokens: list[int], + concat_loss_mask: list[int], + concat_attention_mask: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Validate tokenization and convert to tensors. + + Args: + full_tokens: Full conversation tokens + concat_tokens: Concatenated tokens + concat_loss_mask: Concatenated loss mask + concat_attention_mask: Concatenated attention mask + + Returns: + Tuple of (input_ids, loss_mask, attention_mask) as tensors + """ + full_tokens_list = full_tokens.tolist() + + if len(concat_tokens) != len(full_tokens_list) or not all( + a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True) + ): + logging.warning( + f"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens " + f"length: {len(concat_tokens)}. Using concatenated version." + # f"full tokens text: {self.tokenizer.decode(full_tokens_list)}" + # f"concat tokens text: {self.tokenizer.decode(concat_tokens)}" + ) + return ( + torch.tensor(concat_tokens, dtype=torch.long), + torch.tensor(concat_loss_mask, dtype=torch.long), + torch.tensor(concat_attention_mask, dtype=torch.long), + ) + + return ( + full_tokens, + torch.tensor(concat_loss_mask, dtype=torch.long), + torch.tensor(concat_attention_mask, dtype=torch.long), + ) + def __getitem__(self, item): tokenizer = self.tokenizer messages = self.messages[item] + tools = self.tools[item] if self.tools is not None else None + enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None - # First, get the full conversation tokens - full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) - input_ids = full_tokens[0] # The output is already a tensor - attention_mask = torch.ones_like(input_ids) - - # Create loss mask by identifying assistant responses - loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + if self.tools is not None: + tools = json.loads(self.tools[item]) + else: + tools = None - # Process each message to find assistant responses - for i, msg in enumerate(messages): - # Get tokens for messages up to this point to find the start position - prefix_messages = messages[: i + 1] - prefix_tokens = tokenizer.apply_chat_template(prefix_messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) + # First, get the full conversation tokens + try: + full_tokens = tokenizer.apply_chat_template( + messages, + tools=tools, + tokenize=True, + return_tensors="pt", + add_generation_prompt=False, + enable_thinking=enable_thinking, + ) + except Exception as e: + logging.error( + f"Error applying chat template: {e}\nMessages: {messages}\nTools: {tools}\nEnable thinking: " + f"{enable_thinking}" + ) + raise - # Get tokens for messages up to previous point - prev_tokens = tokenizer.apply_chat_template(messages[:i], tokenize=True, return_tensors="pt", add_generation_prompt=False) if i > 0 else None + # Track concatenated tokens for validation + concat_tokens = [] + concat_loss_mask = [] + concat_attention_mask = [] - # Calculate start and end positions - start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 - end_pos = prefix_tokens[0].shape[0] + i = 0 + while i < len(messages): + cur_messages = messages[i] + if cur_messages["role"] == "assistant": + # Process assistant message + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i += 1 + elif cur_messages["role"] == "tool": + # Process consecutive tool messages + st = i + ed = i + 1 + while ed < len(messages) and messages[ed]["role"] == "tool": + ed += 1 + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, st, ed, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i = ed + elif cur_messages["role"] in ["user", "system"]: + # Process user or system message + if cur_messages["role"] == "system" and i != 0: + raise ValueError("System message should be the first message") + tokens, loss_mask, attention_mask = self._process_message_tokens( + messages, i, i + 1, enable_thinking=enable_thinking, tools=tools + ) + concat_tokens.extend(tokens) + concat_loss_mask.extend(loss_mask) + concat_attention_mask.extend(attention_mask) + i += 1 + else: + raise ValueError(f"Unknown role: {cur_messages['role']}") - # If this is an assistant message, set loss mask - if msg["role"] == "assistant": - loss_mask[start_pos:end_pos] = 1 + # Validate and convert tokens + input_ids, loss_mask, attention_mask = self._validate_and_convert_tokens( + full_tokens[0], concat_tokens, concat_loss_mask, concat_attention_mask + ) # Handle sequence length sequence_length = input_ids.shape[0] if sequence_length < self.max_length: # Pad sequences pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * pad_token_id - padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) - padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) + padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype) + padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype) input_ids = torch.cat((input_ids, padded_input_ids)) attention_mask = torch.cat((attention_mask, padded_attention_mask)) diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 93d765efe..87036f37e 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -# The file is temporarily reverted by Reasoning360 to use `dataframe` rather than `Dataset`, to support heterogeneous keys of multi-domain data +# The file is temporarily reverted by Reasoning360 to use `dataframe` rather than `Dataset`, +# to support heterogeneous keys of multi-domain data import copy import os from collections import defaultdict -from typing import List, Optional, Union +from typing import Optional import numpy as np import pandas as pd @@ -82,7 +83,7 @@ class RLHFDataset(Dataset): def __init__( self, - data_files: Union[str, List[str]], + data_files: str | list[str], tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin] = None, prompt_key="prompt", @@ -132,7 +133,7 @@ def __init__( self.filter_prompts = filter_prompts parquet_files = data_files - if not isinstance(parquet_files, (List, ListConfig)): + if not isinstance(parquet_files, list | ListConfig): parquet_files = [parquet_files] self.parquet_files = copy.deepcopy(parquet_files) self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume @@ -182,7 +183,12 @@ def _read_files_and_tokenize(self): prompt_key = self.prompt_key self.dataframe = self.dataframe[ self.dataframe.apply( - lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True) if doc["apply_chat_template"] else tokenizer.encode(doc["raw_prompt"])) <= self.max_prompt_length, + lambda doc: len( + tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True) + if doc["apply_chat_template"] + else tokenizer.encode(doc["raw_prompt"]) + ) + <= self.max_prompt_length, axis=1, ) ] @@ -209,7 +215,11 @@ def __getitem__(self, item): chat = row_dict.pop(self.prompt_key) - prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False) if row_dict["apply_chat_template"] else row_dict["raw_prompt"] + prompt_with_chat_template = ( + self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False) + if row_dict["apply_chat_template"] + else row_dict["raw_prompt"] + ) is_multi_modal = self.image_key in row_dict if is_multi_modal: # expand image token @@ -225,12 +235,16 @@ def __getitem__(self, item): while "" in prompt_with_chat_template: prompt_with_chat_template = prompt_with_chat_template.replace( "", - "<|vision_start|>" + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length) + "<|vision_end|>", + "<|vision_start|>" + + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length) + + "<|vision_end|>", 1, ) index += 1 - prompt_with_chat_template = prompt_with_chat_template.replace("<|placeholder|>", self.processor.image_token) + prompt_with_chat_template = prompt_with_chat_template.replace( + "<|placeholder|>", self.processor.image_token + ) else: raw_prompt = prompt_with_chat_template diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py index 01f83af41..7af792343 100644 --- a/verl/utils/dataset/rm_dataset.py +++ b/verl/utils/dataset/rm_dataset.py @@ -13,7 +13,6 @@ # limitations under the License. import os -from typing import List, Union import pandas as pd import torch @@ -39,7 +38,7 @@ def download_files_distributed(download_fn): class RMDataset(Dataset): def __init__( self, - parquet_files: Union[str, List[str]], + parquet_files: str | list[str], tokenizer, prompt_key="prompt", chosen_key="chosen", @@ -48,7 +47,7 @@ def __init__( add_eos=True, cache_dir="~/.cache/verl/rm", ): - if not isinstance(parquet_files, List): + if not isinstance(parquet_files, list): parquet_files = [parquet_files] self.parquet_files = parquet_files @@ -100,8 +99,12 @@ def _pad_to_length(self, input_ids, attention_mask): curr_length = input_ids.shape[-1] if curr_length < self.max_length: - input_ids = torch.cat((input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1) - attention_mask = torch.cat((attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1) + input_ids = torch.cat( + (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1 + ) + attention_mask = torch.cat( + (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1 + ) elif curr_length > self.max_length: input_ids = input_ids[: self.max_length] attention_mask = attention_mask[: self.max_length] @@ -119,7 +122,9 @@ def __getitem__(self, item): if self.add_eos: chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1) - rejected_response_ids = torch.cat((rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1) + rejected_response_ids = torch.cat( + (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1 + ) chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1) chosen_attention_mask = torch.ones_like(chosen_input_ids) diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index d01b174f0..fbbe8b304 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -18,10 +18,9 @@ Each parquet file contains """ -from typing import List, Union - import pandas as pd import torch +from omegaconf.listconfig import ListConfig from torch.utils.data import Dataset from transformers import PreTrainedTokenizer @@ -38,20 +37,20 @@ class SFTDataset(Dataset): config (OmegaConf): the data config """ - def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config): + def __init__(self, parquet_files: str | ListConfig, tokenizer, config): prompt_key = config.get("prompt_key", "prompt") prompt_dict_keys = config.get("prompt_dict_keys", None) response_key = config.get("response_key", "response") response_dict_keys = config.get("response_dict_keys", None) max_length = config.get("max_length", 1024) truncation = config.get("truncation", "error") - use_shm = config.get('use_shm', False) + use_shm = config.get("use_shm", False) assert truncation in ["error", "left", "right"] self.truncation = truncation self.use_shm = use_shm - if not isinstance(parquet_files, List): + if not isinstance(parquet_files, ListConfig): parquet_files = [parquet_files] self.parquet_files = parquet_files @@ -59,8 +58,8 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config): tokenizer = hf_tokenizer(tokenizer) self.tokenizer: PreTrainedTokenizer = tokenizer - self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key] - self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key] + self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key] + self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key] self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else [] self.response_dict_keys = response_dict_keys if response_dict_keys else [] @@ -78,7 +77,7 @@ def series_to_item(ls): import numpy import pandas - while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: ls = ls[0] return ls @@ -152,7 +151,10 @@ def __getitem__(self, item): # padding to max length sequence_length = input_ids.shape[0] if sequence_length < self.max_length: - padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * self.tokenizer.pad_token_id + padded_input_ids = ( + torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) + * self.tokenizer.pad_token_id + ) padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) input_ids = torch.cat((input_ids, padded_input_ids)) diff --git a/verl/utils/dataset/vision_utils.py b/verl/utils/dataset/vision_utils.py index 832abde91..75cce7f6a 100644 --- a/verl/utils/dataset/vision_utils.py +++ b/verl/utils/dataset/vision_utils.py @@ -13,14 +13,14 @@ # limitations under the License. from io import BytesIO -from typing import Optional, Union +from typing import Optional import torch from PIL import Image from qwen_vl_utils import fetch_image, fetch_video -def process_image(image: Union[dict, Image.Image]) -> Image.Image: +def process_image(image: dict | Image.Image) -> Image.Image: if isinstance(image, Image.Image): return image.convert("RGB") @@ -90,3 +90,28 @@ def process_video( video["max_frames"] = fps_max_frames return fetch_video(video) + + +def process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs): + # Adjust image bounds based on left padding and cumulative sequence lengths + # This is necessary for MiniCPM-o's vision-language alignment + left_padding_length = torch.argmax(attention_mask, dim=1) + image_bounds = [] + for i in range(len(multi_modal_inputs["image_bound"])): + image_bound = ( + multi_modal_inputs["image_bound"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i] + ) + image_bounds.append(image_bound) + + # Flatten pixel values list for MiniCPM-o processing + pixel_values = [] + for i in range(len(multi_modal_inputs["pixel_values"])): + pixel_values.extend([p for p in multi_modal_inputs["pixel_values"][i]]) + + multi_modal_inputs["pixel_values"] = [pixel_values] + multi_modal_inputs["image_bound"] = [torch.vstack(image_bounds)] + multi_modal_inputs["tgt_sizes"] = [torch.vstack(multi_modal_inputs["tgt_sizes"])] + multi_modal_inputs["input_ids"] = input_ids + multi_modal_inputs["attention_mask"] = attention_mask + multi_modal_inputs["position_ids"] = position_ids + return {"data": multi_modal_inputs} diff --git a/verl/utils/debug/__init__.py b/verl/utils/debug/__init__.py index 3718037c1..eb67df1b7 100644 --- a/verl/utils/debug/__init__.py +++ b/verl/utils/debug/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .performance import GPUMemoryLogger, log_gpu_memory_usage, log_print - -__all__ = ["GPUMemoryLogger", "log_gpu_memory_usage"] +# APIs kept for backward compatibility purpose +# For new features please develop in verl/utils/profiler/ +from ..profiler import * # noqa diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index 607c852e5..9186e125a 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -12,89 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import inspect -import logging -from typing import Any, Tuple - -import torch.distributed as dist - -from verl.utils.device import get_torch_device -from verl.utils.logger.aggregate_logger import DecoratorLoggerBase - - -def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]: - """Get current memory usage.""" - assert unit in ["GB", "MB", "KB"] - divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 - mem_allocated = get_torch_device().memory_allocated() - mem_reserved = get_torch_device().memory_reserved() - # use get_torch_device().mem_get_info to profile device memory - # since vllm's sleep mode works below pytorch - # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 - mem_free, mem_total = get_torch_device().mem_get_info() - mem_used = mem_total - mem_free - mem_allocated = f"{mem_allocated / divisor:.{precision}f}" - mem_reserved = f"{mem_reserved / divisor:.{precision}f}" - mem_used = f"{mem_used / divisor:.{precision}f}" - mem_total = f"{mem_total / divisor:.{precision}f}" - return mem_allocated, mem_reserved, mem_used, mem_total - - -def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): - if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}" - - if logger is None: - print(message) - else: - logger.log(msg=message, level=level) - - -class GPUMemoryLogger(DecoratorLoggerBase): - """A decorator class to log GPU memory usage. - - Example: - >>> from verl.utils.debug.performance import GPUMemoryLogger - >>> @GPUMemoryLogger(role="actor") - >>> def update_actor(self, batch): - ... # real actor update logics - ... return - """ - - def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): - if dist.is_initialized() and dist.get_world_size() > 1: - rank = dist.get_rank() - else: - rank = 0 - super().__init__(role, logger, level, rank, log_only_rank_0) - - def __call__(self, decorated_function: callable): - def f(*args, **kwargs): - return self.log(decorated_function, *args, **kwargs) - - return f - - def log(self, func, *args, **kwargs): - name = func.__name__ - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = f"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}" - self.logging_function(message) - - output = func(*args, **kwargs) - - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}" - - self.logging_function(message) - return output - -def log_print(ctn: Any): - current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') - - frame = inspect.currentframe().f_back - function_name = frame.f_code.co_name - line_number = frame.f_lineno - file_name = frame.f_code.co_filename.split('/')[-1] - print(f"[{file_name}:{line_number}:{function_name}]: {ctn}") \ No newline at end of file +# APIs kept for backward compatibility purpose +# This file is deprecated, for new features please develop in profiler/performance.py +from verl.utils.profiler.performance import simple_timer, reduce_timing # noqa diff --git a/verl/utils/debug/profile.py b/verl/utils/debug/profile.py deleted file mode 100644 index e06c24cde..000000000 --- a/verl/utils/debug/profile.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -import torch.distributed - - -class Profiler: - def __init__(self, config): - # note : if we do not set use_profile, it will be set as None, so that all function will be skip - self.config = config - self.skip_prof = False - self.saved = False - self.prof = None - self.rank = torch.distributed.get_rank() - # we need to validate the config before using the profiler - self._validate() - if config.use_profile and self.rank in self.config.profile_ranks: - print(f"[Profiler] Profiler init for rank {self.rank}") - - self.prof = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=max(self.config.step_start - 1, 0), - warmup=1 if self.config.step_start > 0 else 0, - active=self.config.step_end - self.config.step_start, - repeat=1, - ), - record_shapes=True, - with_stack=True, - ) - - def _validate(self): - if self.config.use_profile: - if self.config.profile_ranks is None: - print("[WARNING] Profile ranks is not set, default to rank 0") - self.config.profile_ranks = [0] - assert self.config.step_start >= 0, "[ERROR] Profile step start must be greater than 0" - assert self.config.step_end >= 0, "[ERROR] Profile step end must be greater than 0" - assert self.config.step_start < self.config.step_end, "[ERROR] Profile step start must be less than step end" - - def check(self): - return self.prof is not None and not self.skip_prof - - def start(self): - if self.check(): - print(f"[Profiler] started for rank {self.rank}") - self.prof.start() - - def step(self): - if self.check(): - self.prof.step() - - def stop(self): - if self.check(): - print(f"[Profiler] stopped for rank {self.rank}") - self.prof.stop() - - def save(self): - if self.prof is not None and not self.saved: - if not os.path.exists(self.config.save_path): - os.makedirs(self.config.save_path) - save_file_name = f"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json" - print(f"[Profiler] Saving trace to {self.config.save_path + save_file_name}") - self.prof.export_chrome_trace(self.config.save_path + save_file_name) - self.skip_prof = True - self.saved = True - - def stop_and_save(self): - if self.check(): - self.stop() - self.save() - - def stop_trace(self): - if self.check(): - print(f"[Profiler] Trace stopped for rank {self.rank}") - self.skip_prof = True diff --git a/verl/utils/debug/trajectory_tracker.py b/verl/utils/debug/trajectory_tracker.py index fe6b44fe1..73afb8540 100644 --- a/verl/utils/debug/trajectory_tracker.py +++ b/verl/utils/debug/trajectory_tracker.py @@ -80,7 +80,9 @@ def get_trajectory_tracker(): hdfs_dir = os.getenv("VERL_TRACKER_HDFS_DIR", default=None) verbose = os.getenv("VERL_TRACKER_VERBOSE", default="0") == "1" assert hdfs_dir is not None - tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, lifetime="detached").remote(hdfs_dir, verbose) + tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, lifetime="detached").remote( + hdfs_dir, verbose + ) return tracker diff --git a/verl/utils/device.py b/verl/utils/device.py index ee9e279d2..ed85b0d5b 100644 --- a/verl/utils/device.py +++ b/verl/utils/device.py @@ -29,6 +29,14 @@ def is_torch_npu_available() -> bool: is_npu_available = is_torch_npu_available() +def get_visible_devices_keyword() -> str: + """Function that gets visible devices keyword name. + Returns: + 'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES` + """ + return "CUDA_VISIBLE_DEVICES" if is_cuda_available else "ASCEND_RT_VISIBLE_DEVICES" + + def get_device_name() -> str: """Function that gets the torch.device based on the current machine. This currently only supports CPU, CUDA, NPU. @@ -55,3 +63,24 @@ def get_torch_device() -> any: except AttributeError: logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") return torch.cuda + + +def get_device_id() -> int: + """Return current device id based on the device type. + Returns: + device index + """ + return get_torch_device().current_device() + + +def get_nccl_backend() -> str: + """Return nccl backend type based on the device type. + Returns: + nccl backend type string. + """ + if is_cuda_available: + return "nccl" + elif is_npu_available: + return "hccl" + else: + raise RuntimeError(f"No available nccl backend found on device type {get_device_name()}.") diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 1d10638ea..610b5d4c9 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -17,13 +17,17 @@ import torch.distributed -from verl.utils.device import get_torch_device, is_cuda_available +from verl.utils.device import get_nccl_backend, get_torch_device def initialize_global_process_group(timeout_second=36000): from datetime import timedelta - torch.distributed.init_process_group("nccl" if is_cuda_available else "hccl", timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group( + get_nccl_backend(), + timeout=timedelta(seconds=timeout_second), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) diff --git a/verl/utils/experimental/torch_functional.py b/verl/utils/experimental/torch_functional.py index 9c225a60b..0b4ce5c61 100644 --- a/verl/utils/experimental/torch_functional.py +++ b/verl/utils/experimental/torch_functional.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch @@ -21,8 +21,8 @@ def _fused_linear_for_ppo_fwd( hidden_states: torch.FloatTensor, vocab_weights: torch.FloatTensor, input_ids: torch.LongTensor, - temperature: float = 1.0 -) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + temperature: float = 1.0, +) -> tuple[torch.FloatTensor, torch.FloatTensor]: logits = (hidden_states @ vocab_weights.t()) / temperature orig_dtype = logits.dtype logits = logits.to(torch.float32) @@ -44,7 +44,7 @@ def _fused_linear_for_ppo_bwd( vocab_weights: torch.FloatTensor, input_ids: torch.LongTensor, temperature: float = 1.0, -) -> Tuple[torch.FloatTensor, torch.FloatTensor]: +) -> tuple[torch.FloatTensor, torch.FloatTensor]: logits = (hidden_states @ vocab_weights.t()) / temperature orig_dtype = logits.dtype logits = logits.to(torch.float32) @@ -67,13 +67,12 @@ def _fused_linear_for_ppo_bwd( dlogits = dlogits.to(orig_dtype) / temperature dhidden_states = dlogits @ vocab_weights - dvocab_weights = (dlogits.t() @ hidden_states) + dvocab_weights = dlogits.t() @ hidden_states return dhidden_states, dvocab_weights class FusedLinearForPPOFunction(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -82,7 +81,7 @@ def forward( input_ids: torch.LongTensor, temperature: float = 1.0, chunk_size: int = 512, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: ctx.set_materialize_grads(False) # Cast to a 2D tensor of the shape [T, D] for ease of working @@ -195,7 +194,6 @@ def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[to class FusedLinearForPPO(torch.nn.Module): - def __init__(self, chunk_size: int = 512): super().__init__() @@ -207,7 +205,7 @@ def forward( vocab_weights: torch.FloatTensor, input_ids: torch.LongTensor, temperature: float = 1.0, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: input_ids = input_ids.to(torch.int64) return FusedLinearForPPOFunction.apply( hidden_states, diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 59d6b108d..1bed92902 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -16,7 +16,17 @@ from verl.utils.device import get_torch_device -VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "deepseek_v3"} +VALID_CONFIG_TYPE = { + "llama", + "qwen2", + "qwen2_vl", + "qwen2_5_vl", + "qwen3", + "qwen3_moe", + "deepseek_v3", + "minicpmv", + "minicpmo", +} def get_device_flops(unit="T"): @@ -35,7 +45,7 @@ def unit_convert(number, level): if "MI300X" in device_name: flops = 1336e12 - elif "H100" in device_name or "H800" in device_name: + elif "H100" in device_name or "H800" in device_name or "H200" in device_name: flops = 989e12 elif "A100" in device_name or "A800" in device_name: flops = 312e12 @@ -47,6 +57,8 @@ def unit_convert(number, level): flops = 148e12 elif "910B" in device_name: flops = 354e12 + elif "RTX 3070 Ti" in device_name: + flops = 21.75e12 flops_unit = unit_convert(flops, unit) return flops_unit @@ -63,16 +75,22 @@ class FlopsCounter: def __init__(self, config: PretrainedConfig): if config.model_type not in VALID_CONFIG_TYPE: - print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be zero.") + print( + f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be " + f"zero." + ) self.estimate_func = { "qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops, + "qwen2_moe": self._estimate_qwen2_moe_flops, "qwen2_vl": self._estimate_qwen2_flops, "qwen2_5_vl": self._estimate_qwen2_flops, "qwen3": self._estimate_qwen2_flops, - "qwen3_moe": self._estimate_qwen3_moe_flops, + "qwen3_moe": self._estimate_qwen2_moe_flops, "deepseek_v3": self._estimate_deepseek_v3_flops, + "minicpmv": self._estimate_qwen2_flops, + "minicpmo": self._estimate_qwen2_flops, } self.config = config @@ -139,11 +157,19 @@ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim) - attn_linear_N += num_query_heads * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) * self.config.kv_lora_rank + attn_linear_N += ( + num_query_heads + * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) + * self.config.kv_lora_rank + ) attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm - moe_N = (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + emd_and_lm_head_N + moe_N = ( + (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) + + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + + emd_and_lm_head_N + ) # non-attn all_layer & all_token fwd & bwd flops dense_N_flops = 6 * moe_N * tokens_sum @@ -159,13 +185,13 @@ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): return flops_achieved - def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time): + def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time): hidden_size = self.config.hidden_size vocab_size = self.config.vocab_size num_hidden_layers = self.config.num_hidden_layers num_key_value_heads = self.config.num_key_value_heads num_attention_heads = self.config.num_attention_heads - moe_intermediate_size = self.config.moe_intermediate_size + moe_intermediate_size = self.config.moe_intermediate_size moe_topk = self.config.num_experts_per_tok num_experts = self.config.num_experts @@ -195,13 +221,13 @@ def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time): flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 return flops_achieved - def estimate_flops(self, batch_seqlens, delta_time): """ Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. Args: - batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch. + batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the + current batch. delta_time (float): The time taken to process the batch, in seconds. Returns: diff --git a/verl/utils/fs.py b/verl/utils/fs.py index 5a7b80651..7cc11300f 100644 --- a/verl/utils/fs.py +++ b/verl/utils/fs.py @@ -78,6 +78,7 @@ def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: dst = os.path.join(temp_dir, os.path.basename(hdfs_path)) return dst + def verify_copy(src: str, dest: str) -> bool: """ verify the copy of src to dest by comparing their sizes and file structures. @@ -105,7 +106,7 @@ def verify_copy(src: str, dest: str) -> bool: for root, dirs, files in os.walk(src): rel_path = os.path.relpath(root, src) - dest_root = os.path.join(dest, rel_path) if rel_path != '.' else dest + dest_root = os.path.join(dest, rel_path) if rel_path != "." else dest if not os.path.exists(dest_root): return False @@ -137,18 +138,21 @@ def verify_copy(src: str, dest: str) -> bool: return True -def copy_to_shm(src:str): +def copy_to_shm(src: str): """ - Load the model into /dev/shm to make the process of loading the model multiple times more efficient. + Load the model into /dev/shm to make the process of loading the model multiple times more efficient. """ - shm_model_root = '/dev/shm/verl-cache/' + shm_model_root = "/dev/shm/verl-cache/" src_abs = os.path.abspath(os.path.normpath(src)) - dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode('utf-8')).hexdigest()) + dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode("utf-8")).hexdigest()) os.makedirs(dest, exist_ok=True) dest = os.path.join(dest, os.path.basename(src_abs)) if os.path.exists(dest) and verify_copy(src, dest): # inform user and depends on him - print(f"[WARNING]: The memory model path {dest} already exists. If it is not you want, please clear it and restart the task.") + print( + f"[WARNING]: The memory model path {dest} already exists. If it is not you want, please clear it and " + f"restart the task." + ) else: if os.path.isdir(src): shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True) @@ -156,6 +160,7 @@ def copy_to_shm(src:str): shutil.copy2(src, dest) return dest + def _record_directory_structure(folder_path): record_file = os.path.join(folder_path, ".directory_record.txt") with open(record_file, "w") as f: @@ -187,7 +192,9 @@ def _check_directory_structure(folder_path, record_file): return existing_entries == recorded_entries -def copy_to_local(src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False, use_shm:bool=False) -> str: +def copy_to_local( + src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False, use_shm: bool = False +) -> str: """Copy files/directories from HDFS to local cache with validation. Args: @@ -208,7 +215,10 @@ def copy_to_local(src: str, cache_dir=None, filelock=".file.lock", verbose=False return copy_to_shm(local_path) return local_path -def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False) -> str: + +def copy_local_path_from_hdfs( + src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False +) -> str: """Deprecated. Please use copy_to_local instead.""" from filelock import FileLock @@ -249,3 +259,34 @@ def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock=".file.lock", v return local_path else: return src + + +def local_mkdir_safe(path): + """_summary_ + Thread-safe directory creation function that ensures the directory is created + even if multiple processes attempt to create it simultaneously. + + Args: + path (str): The path to create a directory at. + """ + + from filelock import FileLock + + if not os.path.isabs(path): + working_dir = os.getcwd() + path = os.path.join(working_dir, path) + + # Using hash value of path as lock file name to avoid long file name + lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" + lock_path = os.path.join(tempfile.gettempdir(), lock_filename) + + try: + with FileLock(lock_path, timeout=60): # Add timeout + # make a new dir + os.makedirs(path, exist_ok=True) + except Exception as e: + print(f"Warning: Failed to acquire lock for {path}: {e}") + # Even if the lock is not acquired, try to create the directory + os.makedirs(path, exist_ok=True) + + return path diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index e5210d1c2..7465b400e 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -19,7 +19,6 @@ import os from collections import OrderedDict from contextlib import contextmanager, nullcontext -from typing import Dict import torch import torch.distributed as dist @@ -31,7 +30,7 @@ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name -from verl.utils.device import get_device_name, get_torch_device +from verl.utils.device import get_device_id, get_device_name, get_torch_device if version.parse(torch.__version__) >= version.parse("2.6"): from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard @@ -43,7 +42,7 @@ def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: - x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + x = x.to_empty(device=get_device_id(), recurse=False) get_torch_device().empty_cache() return x @@ -75,7 +74,8 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False): if config is None: config = {} - # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. We will remove this once we have make all config in verl from OmegaConf to data class. + # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. We will remove this + # once we have make all config in verl from OmegaConf to data class. def _get_attr(attr_name, default_value=None): if hasattr(config, "get"): return config.get(attr_name, default_value) @@ -86,7 +86,9 @@ def _get_attr(attr_name, default_value=None): return None default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None) - fsdp_transformer_layer_cls_to_wrap = _get_attr("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap) + fsdp_transformer_layer_cls_to_wrap = _get_attr( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) min_num_params = _get_attr("min_num_params", 0) auto_wrap_policy = None @@ -98,7 +100,11 @@ def _get_attr(attr_name, default_value=None): if is_lora: def lambda_policy_fn(module): - return bool(len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and module.weight.requires_grad) + return bool( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ) lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) policies.append(lambda_policy) @@ -141,7 +147,11 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): if handle._offload_params: continue flat_param = handle.flat_param - assert flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and id(flat_param.data) != id(flat_param._local_shard) and flat_param.data.size() == flat_param._local_shard.size() + assert ( + flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() + and id(flat_param.data) != id(flat_param._local_shard) + and flat_param.data.size() == flat_param._local_shard.size() + ) handle.flat_param_to(torch.device("cpu"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ -168,7 +178,7 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model loading to GPU" - device_id = get_torch_device().current_device() + device_id = get_device_id() for handle in model._all_handles: if handle._offload_params: continue @@ -180,7 +190,7 @@ def load_fsdp_model_to_gpu(model: FSDP): @torch.no_grad() def load_fsdp2_model_to_gpu(model): - device = torch.cuda.current_device() + device = get_device_id() for param in model.parameters(): param.data = param.data.to(device, non_blocking=True) @@ -282,7 +292,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} - device = get_torch_device().current_device() + device = get_device_id() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -297,7 +307,7 @@ def parallel_load_safetensors(filepath): return shard_states -def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]): +def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]): """ Generate a function to initialize sub-modules in the `module` with `shard_states` from huggingface checkpoint. @@ -311,7 +321,9 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor """ state2fqn = {} - for name, state in itertools.chain(module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)): + for name, state in itertools.chain( + module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) + ): state2fqn.setdefault(state, []).append(name) # remove standalone parameters and buffers shared = {s for s, names in state2fqn.items() if len(names) > 1} @@ -320,13 +332,13 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = get_torch_device().current_device() + device = get_device_id() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer param = torch.empty_like(state.data, device=device) loaded = shard_states[param_name] - if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)): + if isinstance(loaded, torch.nn.Parameter | torch.Tensor): # NOTE: loaded.dtype can be different with param.dtype param.data.copy_(loaded.data) dist.broadcast(param.data, src=dist.get_rank()) @@ -348,7 +360,10 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): # non-persistent buffers will not be saved in state dict, we can safely skip it if (not is_param) and fqn not in shard_states: if state.is_meta: - raise RuntimeError(f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.") + raise RuntimeError( + f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved " + f"in checkpoint and user should guarantee to init in CPU / GPU device." + ) continue # for shared parameter, we get it from the first time it is created if state in shared: @@ -392,6 +407,42 @@ def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): return nullcontext() +def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True): + """ + Get the full state dict from an FSDP model. + + Args: + model (torch.nn.Module): The FSDP model to get state dict from + offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True. + rank0_only (bool, optional): Whether to only get state dict on rank 0. Defaults to True. + + Returns: + dict: The full state dict of the model + + Raises: + NotImplementedError: If the FSDP version is unknown + """ + if fsdp_version(model) == 1: + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + + state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) + with get_fsdp_state_ctx( + model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None + ): + state_dict = model.state_dict() + return state_dict + elif fsdp_version(model) == 2: + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + + state_dict_config = StateDictOptions( + full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only + ) + state_dict = get_model_state_dict(model, options=state_dict_config) + return state_dict + else: + raise NotImplementedError(f"Unknown FSDP version {fsdp_version}") + + def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): """ Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the @@ -405,9 +456,9 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_ # To broadcast, it needs to be instantiated in the GPU. if dist.get_rank() == 0: - model = model.to(device=torch.cuda.current_device(), non_blocking=True) + model = model.to(device=get_device_id(), non_blocking=True) else: - model = model.to_empty(device=torch.cuda.current_device()) + model = model.to_empty(device=get_device_id()) cpu_offload = cpu_offload is not None options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) @@ -420,7 +471,7 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_ if cpu_offload: model.to("cpu", non_blocking=True) for buf in model.buffers(): - buf.data = buf.data.to(torch.cuda.current_device()) + buf.data = buf.data.to(get_device_id()) def apply_fsdp2(model, fsdp_kwargs, config): @@ -428,7 +479,9 @@ def apply_fsdp2(model, fsdp_kwargs, config): assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) - fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap) + fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] @@ -437,7 +490,9 @@ def apply_fsdp2(model, fsdp_kwargs, config): modules = [] for name, module in model.named_modules(): - if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings): + if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or ( + isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings + ): modules.append(module) for idx, module in enumerate(modules): @@ -456,7 +511,7 @@ def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinit parameters = list(parameters) grads = [p.grad for p in parameters if p.grad is not None] total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) - total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True) + total_norm = total_norm.to(get_device_id(), non_blocking=True) _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm @@ -489,8 +544,13 @@ def __prefix_submodules(module, prefix): if fsdp_version(submodule) > 0: with FSDP.summon_full_params(submodule, writeback=False): sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict()) - sub_lora_params = {f"{prefix}.{name}": param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() for name, param in sub_lora_params.items()} + sub_lora_params = { + f"{prefix}.{name}": param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + for name, param in sub_lora_params.items() + } lora_params.update(sub_lora_params) submodule._is_root = False - torch.cuda.empty_cache() + get_torch_device().empty_cache() return lora_params diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py index 6d62fd86b..fc75541e6 100644 --- a/verl/utils/import_utils.py +++ b/verl/utils/import_utils.py @@ -16,9 +16,12 @@ We assume package availability won't change during runtime. """ +import importlib import importlib.util -from functools import cache -from typing import List, Optional +import os +import warnings +from functools import cache, wraps +from typing import Optional @cache @@ -48,10 +51,28 @@ def is_sglang_available(): return sglang_spec is not None +@cache +def is_nvtx_available(): + try: + nvtx_spec = importlib.util.find_spec("nvtx") + except ModuleNotFoundError: + nvtx_spec = None + return nvtx_spec is not None + + +@cache +def is_trl_available(): + try: + trl_spec = importlib.util.find_spec("trl") + except ModuleNotFoundError: + trl_spec = None + return trl_spec is not None + + def import_external_libs(external_libs=None): if external_libs is None: return - if not isinstance(external_libs, List): + if not isinstance(external_libs, list): external_libs = [external_libs] import importlib @@ -59,23 +80,33 @@ def import_external_libs(external_libs=None): importlib.import_module(external_lib) -def load_extern_type(file_path: Optional[str], type_name: Optional[str]): +def load_extern_type(file_path: Optional[str], type_name: Optional[str]) -> type: """Load a external data type based on the file path and type name""" - import importlib.util - import os - if not file_path: return None - if not os.path.exists(file_path): - raise FileNotFoundError(f"Custom type file '{file_path}' not found.") - - spec = importlib.util.spec_from_file_location("custom_module", file_path) - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading module from '{file_path}'") from e + if file_path.startswith("pkg://"): + # pkg://verl.utils.dataset.rl_dataset + # pkg://verl/utils/dataset/rl_dataset + module_name = file_path[6:].replace("/", ".") + module = importlib.import_module(module_name) + + else: + # file://verl/utils/dataset/rl_dataset + # file:///path/to/verl/utils/dataset/rl_dataset.py + # or without file:// prefix + if file_path.startswith("file://"): + file_path = file_path[7:] + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Custom type file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}'") from e if not hasattr(module, type_name): raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.") @@ -89,19 +120,37 @@ def _get_qualified_name(func): qualname = func.__qualname__ return f"{module}.{qualname}" + def deprecated(replacement: str = ""): - """Decorator to mark APIs as deprecated.""" - import functools - import warnings - - def decorator(func): - qualified_name = _get_qualified_name(func) - @functools.wraps(func) - def wrapped(*args, **kwargs): - msg = f"Warning: API '{qualified_name}' is deprecated." - if replacement: - msg += f" Please use '{replacement}' instead." - warnings.warn(msg, category=DeprecationWarning, stacklevel=2) - return func(*args, **kwargs) - return wrapped - return decorator \ No newline at end of file + """Decorator to mark functions or classes as deprecated.""" + + def decorator(obj): + qualified_name = _get_qualified_name(obj) + + if isinstance(obj, type): + original_init = obj.__init__ + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs): + msg = f"Warning: Class '{qualified_name}' is deprecated." + if replacement: + msg += f" Please use '{replacement}' instead." + warnings.warn(msg, category=FutureWarning, stacklevel=2) + return original_init(self, *args, **kwargs) + + obj.__init__ = wrapped_init + return obj + + else: + + @wraps(obj) + def wrapped(*args, **kwargs): + msg = f"Warning: Function '{qualified_name}' is deprecated." + if replacement: + msg += f" Please use '{replacement}' instead." + warnings.warn(msg, category=FutureWarning, stacklevel=2) + return obj(*args, **kwargs) + + return wrapped + + return decorator diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py new file mode 100644 index 000000000..e32d583d3 --- /dev/null +++ b/verl/utils/kernel/__init__.py @@ -0,0 +1,31 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py new file mode 100644 index 000000000..a125bacda --- /dev/null +++ b/verl/utils/kernel/kernels.py @@ -0,0 +1,1553 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementations of the linear cross entropy with token entropy kernel. +""" + +import typing +from dataclasses import dataclass + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +from verl.utils.device import get_torch_device + + +@dataclass +class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ + + _None = 0 + _Sum = 1 + _Mean = 2 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if reduction == "none": + _enum = EntropyReductionEnum._None + elif reduction == "sum": + _enum = EntropyReductionEnum._Sum + elif reduction == "mean": + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if ce_reduction == 0: + _enum = EntropyReductionEnum._None + elif ce_reduction == 1: + _enum = EntropyReductionEnum._Sum + elif ce_reduction == 2: + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + return _enum + + +@dataclass +class BackwardEnum: + """ + Enum for the backward method. + """ + + _Total_Fuse_MN = ( + 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + ) + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size + _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. num_tokens + + +@dataclass +class Config: + _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N + _use_triton: bool = True + + +_config = Config() + + +def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. + """ + global _config + _config._backward = backward_method + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + rcp_temperature: tl.float32, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over K dimension + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + # load the next block of hidden and weight + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( + # (pid_n + 1) * vocab_per_split, vocab_size))), + # other=0.0) + + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + other=0.0, + ) + + # GEMM + logits = tl.dot(_hidden, _weight.trans(), logits) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # scale logits by temperature + logits *= rcp_temperature + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store logprobs + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) + mask &= offs_am < num_tokens + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue( + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + num_tokens, + num_splits, + global_max_ptr, + stride_global_max: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + global_entropy_ptr, + stride_global_entropy: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + foward epilogue + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + + _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + _entropy_b = tl.load( + entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0 + ) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # store entropy_b + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + # store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + global_logprobs = -1 * global_logprobs + if reduction == 0: + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif reduction == 2: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, + num_splits, + reduced_max_ptr, + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_max_ptr, + stride_global_max: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load( + reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _original_max = tl.load( + original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _accu = tl.load( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load( + entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update( + num_tokens, + logprobs_ptr, + stride_logprobs: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accumulate_ptr, + stride_accumulate: tl.int64, + entropy_b_ptr, + stride_entropy_b: tl.int64, + entropy_ptr, + stride_entropy: tl.int64, + logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + if reduction == 0: + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + logprobs_scalar = tl.sum(logprobs, axis=0) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif reduction == 2: + logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + + +_dedicated_stream, _dedicated_events = None, None + + +def efficient_entropy_forward( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[int] = 2, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> list[torch.Tensor]: + """ + forward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + global _dedicated_stream, _dedicated_events + _dedicated_stream = get_torch_device().Stream(hidden.device) + _dedicated_events = [get_torch_device().Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): + logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) + else: + raise ValueError(f"Invalid reduction: {reduction}") + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + + if REDUCTION == EntropyReductionEnum._None: + _logprobs = logprobs + else: + _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + + if _config._use_triton: + # 1D kernel launch, then split the tile + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + + efficient_entropy_kernel_general_mainloop[mainloop_grid]( + _rank, + hidden, + weight, + labels, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + hidden.stride(0), + hidden.stride(1), + weight.stride(0), + weight.stride(1), + _max, + _max.stride(0), + _max.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + _logprobs, + _logprobs.stride(0), + logprobs, + 1.0 / temperature, + ) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + # reduction on maximum and maximum_indices + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( + _max, + _max.stride(0), + _max.stride(1), + num_tokens, + num_splits, + maximum, + maximum.stride(0), + _accu, + _accu.stride(0), + _accu.stride(1), + accumulate, + accumulate.stride(0), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + _logprobs, + _logprobs.stride(0), + logprobs, + REDUCTION, + ) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + get_torch_device().current_stream().record_event(_dedicated_events[0]) + with get_torch_device().stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( + num_tokens, + num_splits, + _max, + _max.stride(0), + _max.stride(1), + _max_backup, + _max_backup.stride(0), + _max_backup.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + ) + get_torch_device().current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( + num_tokens, + _logprobs, + _logprobs.stride(0), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + logprobs, + REDUCTION, + ) + + return (logprobs, entropy, maximum, accumulate, entropy_b) + + +# NOTE: merge d_weight & d_hidden here, split along M & N +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ) + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_mainloop_MN( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward mainloop, where d_logits & d_hidden & d_weight are fused + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n + d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # loop for d_weight & d_hidden + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) + # tl.atomic_add(d_weight_ptrs, + # _d_weight, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) + _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) + tl.atomic_add( + d_weight_ptrs, + _d_weight, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + ) + + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) + tl.atomic_add( + d_hidden_ptrs, + _d_hidden, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + ) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_hidden( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_hidden + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_k = pid // num_pid_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + # iterate over vocab_size + d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): + offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over hidden_size to get logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits + d_logits *= rcp_temperature + + # calculate d_hidden + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) + _weight = tl.load( + weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0 + ) + d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) + + # write back + tl.store( + d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, + d_hidden, + mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_weight( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + pid_n = pid % num_pid_n + pid_k = pid // num_pid_n + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): + offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) + _hidden = tl.load( + hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0 + ) + d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) + + # write back + tl.store( + d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, + d_weight, + mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size), + ) + + +# NOTE: split tile from d_logits' perspective +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + # _weight = tl.load(weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store( + d_logits_ptrs, + d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits_split_N( + split_idx: int, + num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), + other=0.0, + ) + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + # filter d_logits with mask + result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) + + tl.store( + d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask + ) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + should_return_fp32_grad: bool = False, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> list[torch.Tensor]: + """ + backward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + d_hidden, d_weight = None, None + if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + else: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + if _config._backward == BackwardEnum._Total_Fuse_MN: + # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), + 1.0 / temperature, + ) + + elif _config._backward == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + if _config._use_triton: + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + torch.matmul(_d_logits, weight, out=d_hidden) + torch.matmul(_d_logits.T, hidden, out=d_weight) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + elif _config._backward == BackwardEnum._Split_Dlogits_N: + vocab_per_split = 9504 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) + + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + ) + + if split_idx == (num_splits - 1): + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split + _d_logits = _d_logits[:, :vocab_right_bound].contiguous() + + if split_idx == 0: + torch.matmul( + _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden + ) + else: + d_hidden += torch.matmul( + _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + ) + torch.matmul( + _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + ) + + elif _config._backward == BackwardEnum._Split_Dlogits_M: + raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") + + return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py new file mode 100644 index 000000000..733a8152a --- /dev/null +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,117 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import torch +import torch.distributed as dist + +from . import kernels + + +class LinearCrossEntropy(torch.autograd.Function): + @staticmethod + def forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: typing.Optional[float] = 1.0, + reduction: typing.Optional[str] = "none", + dist_process_group: typing.Optional[dist.ProcessGroup] = None, + ) -> list[torch.Tensor]: + """_summary_ + + Args: + ctx (_type_): _description_ + hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size) + weight (torch.Tensor): (vocab_size, hidden_size) + labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, ) + temperature (typing.Optional[float], optional): _description_. Defaults to 1.0. + reduction (typing.Optional[str], optional): _description_. Defaults to "none". + dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None. + + Returns: + typing.List[torch.Tensor]: _description_ + """ + + assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}" + assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}" + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + original_hidden_shape = hidden.shape + if len(hidden.shape) != 2: + hidden = hidden.view(-1, hidden.shape[-1]) # (batch_size * num_tokens, hidden_size) + if len(labels.shape) != 1: + labels = labels.view(-1) + + logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward( + hidden, weight, labels, REDUCTION, temperature, dist_process_group + ) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature + return logprobs, entropy + + @staticmethod + def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors + REDUCTION = ctx.REDUCTION + dist_process_group = ctx.dist_process_group + should_return_fp32_grad = ctx.should_return_fp32_grad + temperature = ctx.temperature + + d_hidden, d_weight = kernels.efficient_entropy_backward( + dlogprobs, + dentropy, + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + REDUCTION, + should_return_fp32_grad, + temperature, + dist_process_group, + ) + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return (d_hidden, d_weight, None, None, None, None) + + +linear_cross_entropy = LinearCrossEntropy.apply diff --git a/verl/utils/logger/__init__.py b/verl/utils/logger/__init__.py index 1ce90c5eb..e3184368b 100644 --- a/verl/utils/logger/__init__.py +++ b/verl/utils/logger/__init__.py @@ -11,3 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from .aggregate_logger import ( + DecoratorLoggerBase, + LocalLogger, + log_with_rank, + print_rank_0, + print_with_rank, + print_with_rank_and_timer, +) + +__all__ = [ + "LocalLogger", + "DecoratorLoggerBase", + "print_rank_0", + "print_with_rank", + "print_with_rank_and_timer", + "log_with_rank", +] diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py index 47d9945cd..d29698acc 100644 --- a/verl/utils/logger/aggregate_logger.py +++ b/verl/utils/logger/aggregate_logger.py @@ -15,25 +15,33 @@ A Ray logger will receive logging info from different processes. """ +import datetime import logging import numbers -from typing import Dict +import pprint +import torch -def concat_dict_to_str(dict: Dict, step): + +def concat_dict_to_str(dict: dict, step): output = [f"step:{step}"] for k, v in dict.items(): if isinstance(v, numbers.Number): - output.append(f"{k}:{v:.3f}") + output.append(f"{k}:{pprint.pformat(v)}") output_str = " - ".join(output) return output_str class LocalLogger: - def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False): + """ + A local logger that logs messages to the console. + + Args: + print_to_console (bool): Whether to print to the console. + """ + + def __init__(self, print_to_console=True): self.print_to_console = print_to_console - if print_to_console: - print("Using LocalLogger is deprecated. The constructor API will change ") def flush(self): pass @@ -44,7 +52,20 @@ def log(self, data, step): class DecoratorLoggerBase: - def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True): + """ + Base class for all decorators that log messages. + + Args: + role (str): The role (the name) of the logger. + logger (logging.Logger): The logger instance to use for logging. + level (int): The logging level. + rank (int): The rank of the process. + log_only_rank_0 (bool): If True, only log for rank 0. + """ + + def __init__( + self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True + ): self.role = role self.logger = logger self.level = level @@ -63,3 +84,57 @@ def log_by_logging(self, log_str): raise ValueError("Logger is not initialized") if not self.log_only_rank_0 or self.rank == 0: self.logger.log(self.level, f"{self.role} {log_str}") + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False): + """_summary_ + Print a message with rank information. + This function prints the message only if `log_only_rank_0` is False or if the rank is 0. + + Args: + message (str): _description_ + rank (int, optional): _description_. Defaults to 0. + log_only_rank_0 (bool, optional): _description_. Defaults to False. + """ + if not log_only_rank_0 or rank == 0: + print(f"[Rank {rank}] {message}", flush=True) + + +def print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False): + """_summary_ + Print a message with rank information and a timestamp. + This function prints the message only if `log_only_rank_0` is False or if the rank is 0. + + Args: + message (str): _description_ + rank (int, optional): _description_. Defaults to 0. + log_only_rank_0 (bool, optional): _description_. Defaults to False. + """ + now = datetime.datetime.now() + message = f"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [Rank {rank}] {message}" + if not log_only_rank_0 or rank == 0: + print(message, flush=True) + + +def log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False): + """_summary_ + Log a message with rank information using a logger. + This function logs the message only if `log_only_rank_0` is False or if the rank is 0. + Args: + message (str): The message to log. + rank (int): The rank of the process. + logger (logging.Logger): The logger instance to use for logging. + level (int, optional): The logging level. Defaults to logging.INFO. + log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False. + """ + if not log_only_rank_0 or rank == 0: + logger.log(level, f"[Rank {rank}] {message}") diff --git a/verl/utils/megatron/dist_checkpointing.py b/verl/utils/megatron/dist_checkpointing.py new file mode 100644 index 000000000..d95752a45 --- /dev/null +++ b/verl/utils/megatron/dist_checkpointing.py @@ -0,0 +1,56 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import dist_checkpointing, mpu +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) + + +def save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save=False): + validate_sharding_integrity = True + # Get checkpointing strategies + save_strategy = get_default_save_sharded_strategy("torch_dist") + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + + # Save model sharded state dicts + async_save_request = dist_checkpointing.save( + sharded_state_dict, + ckpt_path, + sharded_strategy=save_strategy, + async_sharded_save=async_save, + validate_access_integrity=validate_sharding_integrity, + ) + + return async_save_request + + +def load_dist_checkpointing(sharded_state_dict, ckpt_dir): + # Get checkpointing strategies + load_strategy = get_default_load_sharded_strategy(ckpt_dir) + load_strategy = FullyParallelLoadStrategyWrapper( + load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) + ) + + # Load model sharded state dicts + state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy) + + return state_dict diff --git a/verl/utils/megatron/memory.py b/verl/utils/megatron/memory.py index 17a8ee1cf..bc62d427e 100644 --- a/verl/utils/megatron/memory.py +++ b/verl/utils/megatron/memory.py @@ -14,13 +14,15 @@ import torch +from verl.utils.device import get_device_id + class MemoryBuffer: def __init__(self, numel, numel_padded, dtype): self.numel = numel self.numel_padded = numel_padded self.dtype = dtype - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False) + self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False) def zero(self): """Reset the buffer to zero.""" diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py index 30ebf6cc9..100c161a5 100644 --- a/verl/utils/megatron/optimizer.py +++ b/verl/utils/megatron/optimizer.py @@ -15,6 +15,7 @@ from megatron.core.optimizer import OptimizerConfig from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native +from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler def get_megatron_optimizer( @@ -34,4 +35,46 @@ def get_megatron_optimizer( ) -# TODO: add get_optimizer_param_scheduler(optimizer) to implement lr scheuler. +def get_megatron_optimizer_param_scheduler( + optimizer, + config, +): + """ + Get the optimizer parameter scheduler for Megatron. + """ + if config.get("lr_decay_steps", None) is None: + config.lr_decay_steps = config.total_training_steps + wsd_decay_steps = None + if config.get("lr_wsd_decay_steps", None) is not None: + wsd_decay_steps = config.lr_wsd_decay_steps + if config.get("lr_warmup_steps_ratio", None) is not None and ( + config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0 + ): + config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps) + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=config.lr_warmup_init, + max_lr=config.lr, + min_lr=config.min_lr, + lr_warmup_steps=config.lr_warmup_steps, + lr_decay_steps=config.lr_decay_steps, + lr_decay_style=config.lr_decay_style, + start_wd=config.weight_decay, + end_wd=config.weight_decay, + wd_incr_steps=config.total_training_steps, + wd_incr_style=config.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler), + wsd_decay_steps=wsd_decay_steps, + lr_wsd_decay_style=config.lr_wsd_decay_style, + ) + + return opt_param_scheduler + + +def get_megatron_last_lr(optimizer): + """ + Get the last learning rate from the optimizer parameter scheduler. + """ + return optimizer.param_groups[0]["lr"] diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py index 3462e761c..d4a99b9d8 100644 --- a/verl/utils/megatron/tensor_parallel.py +++ b/verl/utils/megatron/tensor_parallel.py @@ -16,7 +16,7 @@ Utilities for using tensor_parallel in megatron """ -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import torch import torch.distributed as dist @@ -27,7 +27,7 @@ from megatron.core import ModelParallelConfig -def update_kwargs_with_config(dictionary: Dict, config: "ModelParallelConfig"): +def update_kwargs_with_config(dictionary: dict, config: "ModelParallelConfig"): dictionary["config"] = config return dictionary @@ -159,7 +159,8 @@ def vocab_parallel_log_probs_from_logits(logits, labels): def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length): - """Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel region. + """Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel + region. This will further reduce the peak memory usage during training Args: @@ -175,7 +176,11 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen) + full_log_probs_rmpad = vocab_parallel_log_probs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled + ) # (total_nnz,) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 17f3c8211..e59d1ed36 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -19,7 +19,7 @@ import gc import os import warnings -from typing import Any, Dict +from typing import Any import torch import torch.nn.functional as F @@ -34,6 +34,8 @@ from transformers import PretrainedConfig import verl.utils.megatron.tensor_parallel as tp_utils +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.fs import local_mkdir_safe from verl.utils.model import normalize_model_name from verl.utils.torch_dtypes import PrecisionType @@ -48,11 +50,17 @@ def get_model( wrap_with_ddp=True, use_distributed_optimizer=True, transformer_config=None, + override_ddp_config=None, ): """Build the model.""" # Build model. - if mpu.get_pipeline_model_parallel_world_size() > 1 and mpu.get_virtual_pipeline_model_parallel_world_size() is not None: - assert model_type != ModelType.encoder_and_decoder, "Interleaved schedule not supported for model with both encoder and decoder" + if ( + mpu.get_pipeline_model_parallel_world_size() > 1 + and mpu.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert model_type != ModelType.encoder_and_decoder, ( + "Interleaved schedule not supported for model with both encoder and decoder" + ) model = [] for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()): mpu.set_virtual_pipeline_model_parallel_rank(i) @@ -62,6 +70,7 @@ def get_model( this_model = model_provider_func(pre_process=pre_process, post_process=post_process) this_model.model_type = model_type model.append(this_model) + mpu.set_virtual_pipeline_model_parallel_rank(0) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() @@ -69,7 +78,9 @@ def get_model( add_decoder = True if model_type == ModelType.encoder_and_decoder: if mpu.get_pipeline_model_parallel_world_size() > 1: - assert mpu.get_pipeline_model_parallel_split_rank() is not None, "Split rank needs to be specified for model with both encoder and decoder" + assert mpu.get_pipeline_model_parallel_split_rank() is not None, ( + "Split rank needs to be specified for model with both encoder and decoder" + ) rank = mpu.get_pipeline_model_parallel_rank() split_rank = mpu.get_pipeline_model_parallel_split_rank() world_size = mpu.get_pipeline_model_parallel_world_size() @@ -77,7 +88,9 @@ def get_model( post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1)) add_encoder = mpu.is_pipeline_stage_before_split() add_decoder = mpu.is_pipeline_stage_after_split() - model = model_provider_func(pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder) + model = model_provider_func( + pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder + ) else: model = model_provider_func(pre_process=pre_process, post_process=post_process) model.model_type = model_type @@ -107,7 +120,7 @@ def get_model( # GPU allocation. if transformer_config is None or (not transformer_config.use_cpu_initialization): for model_module in model: - model_module.cuda(torch.cuda.current_device()) + model_module.to(f"{get_device_name()}:{get_device_id()}") # Fp16 conversion. config: TransformerConfig = get_model_config(model[0]) @@ -118,16 +131,20 @@ def get_model( if wrap_with_ddp: ddp_models = [] + ddp_config_dict = { + "use_distributed_optimizer": use_distributed_optimizer, + "grad_reduce_in_fp32": True, + "overlap_grad_reduce": False, + } + if override_ddp_config is not None: + ddp_config_dict.update(override_ddp_config) + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) for model_chunk_idx, model_chunk in enumerate(model): ddp_model = DDP( config=tfconfig, module=model_chunk, disable_bucketing=(model_chunk_idx > 0), - ddp_config=DistributedDataParallelConfig( - overlap_grad_reduce=False, - use_distributed_optimizer=use_distributed_optimizer, - grad_reduce_in_fp32=True, # [old] accumulate_allreduce_grads_in_fp32=True, - ), + ddp_config=ddp_config, ) ddp_models.append(ddp_model) model = ddp_models @@ -161,7 +178,10 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC dt = PrecisionType.to_dtype(megatron_config.params_dtype) print(f"pipeline_dtype=megatron_config {dt}") qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) - overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) batch_p2p_comm = False transformer_config = TransformerConfig( num_layers=hf_config.num_hidden_layers, @@ -198,12 +218,13 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC return transformer_config -def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig: +def init_megatron_optim_config(optim_config: dict) -> OptimizerConfig: config = OptimizerConfig( - optimizer="adam", + optimizer=optim_config.get("optimizer", "adam"), lr=optim_config.get("lr"), - clip_grad=optim_config.get("clip_grad"), - weight_decay=optim_config.get("weight_decay"), + min_lr=optim_config.get("min_lr", None), + clip_grad=optim_config.get("clip_grad", 1.0), + weight_decay=optim_config.get("weight_decay", 0.01), bf16=True, params_dtype=torch.bfloat16, use_distributed_optimizer=True, @@ -218,7 +239,8 @@ def mcore_model_parallel_config( # WARNING: Code should not reach this point. This function is deprecated and will be removed. # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead. warnings.warn( - "Code should not reach this point. This function is deprecated and will be removed. Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", + "Code should not reach this point. This function is deprecated and will be removed. Please use " + "hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", DeprecationWarning, stacklevel=2, ) @@ -269,7 +291,7 @@ def offload_megatron_model_to_cpu(models): if param.grad is not None: param.grad = param.grad.to("cpu", non_blocking=True) gc.collect() - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -290,13 +312,13 @@ def load_megatron_model_to_gpu(models, load_grad=True): buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) else: # we need this for ref module - device_id = torch.cuda.current_device() + device_id = get_device_id() for _, param in model_chunk.named_parameters(): param.data = param.data.to(device_id, non_blocking=True) if param.grad is not None: param.grad = param.grad.to(device_id, non_blocking=True) gc.collect() - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -357,7 +379,7 @@ def _iter_opts(opt): def load_tensor_to_gpu(tensor): if tensor is None: return - device_id = torch.cuda.current_device() + device_id = get_device_id() tensor.data = tensor.data.to(device_id, non_blocking=True) def load_group_to_gpu(group): @@ -397,7 +419,7 @@ def _iter_opts(opt): if "exp_avg_sq" in v: v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) gc.collect() - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -412,59 +434,28 @@ def _iter_opts(opt): opt_state_dict_values = _opt.optimizer.state.values() for v in opt_state_dict_values: if "exp_avg" in v: - v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True) + v["exp_avg"] = v["exp_avg"].to(get_device_id(), non_blocking=True) if "exp_avg_sq" in v: - v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True) + v["exp_avg_sq"] = v["exp_avg_sq"].to(get_device_id(), non_blocking=True) gc.collect() - torch.cuda.empty_cache() + get_torch_device().empty_cache() -def print_rank_0(message): - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - - -def get_model_checkpoint_path(checkpoint_path): - os.makedirs(checkpoint_path, exist_ok=True) - return os.path.join(checkpoint_path, "model") +def get_dist_checkpoint_path(checkpoint_path): + local_mkdir_safe(checkpoint_path) + local_mkdir_safe(os.path.join(checkpoint_path, "dist_ckpt")) + return os.path.join(checkpoint_path, "dist_ckpt") def get_hf_model_checkpoint_path(checkpoint_path): - os.makedirs(checkpoint_path, exist_ok=True) + local_mkdir_safe(checkpoint_path) + local_mkdir_safe(os.path.join(checkpoint_path, "huggingface")) return os.path.join(checkpoint_path, "huggingface") -def get_hf_config_and_tokenizer_checkpoint_path(checkpoint_path): +def get_transformer_config_checkpoint_path(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) - return os.path.join(checkpoint_path, "hf_config_and_tokenizer") - - -def get_optimizer_checkpoint_path(checkpoint_path, use_distributed_optimizer=True): - os.makedirs(os.path.join(checkpoint_path, "optim"), exist_ok=True) - if not use_distributed_optimizer: - return os.path.join(checkpoint_path, "optim", "optim.pt") - pp_rank = mpu.get_pipeline_model_parallel_rank() - tp_rank = mpu.get_tensor_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - dp_rank = mpu.get_data_parallel_rank() - # TODO: support ep - return os.path.join(checkpoint_path, "optim", f"distrib_optim_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") - - -def get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=True): - # save rng states cause interrupts - os.makedirs(os.path.join(checkpoint_path, "rng_states"), exist_ok=True) - if only_rank0_save: - return os.path.join(checkpoint_path, "rng_states", "rng_states.pt") - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - tp_rank = mpu.get_tensor_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - return os.path.join(checkpoint_path, "rng_states", f"rng_states_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt") + return os.path.join(checkpoint_path, "transformer_config.json") def convert_megatron_model_to_transformers_model( @@ -618,7 +609,9 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): else: tensor_spec = None tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.all_gather_object( + object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() + ) # find the src rank target_tensor_spec = None src_rank = None @@ -631,7 +624,7 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): src_rank = rank assert target_tensor_spec is not None if tensor is None: - tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=torch.cuda.current_device()) + tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id()) if target_tensor_spec[2] is not None: tensor.tensor_model_parallel = target_tensor_spec[2] if target_tensor_spec[3] is not None: @@ -661,12 +654,22 @@ def broadcast_str_from_megatron_pp(obj: Any): obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) obj_output[0] = target_obj - torch.distributed.broadcast_object_list(object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.broadcast_object_list( + object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() + ) return obj_output[0] -def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False): +def default_tp_concat_fn( + layer_name_mapping, + name, + train_params, + infer_params, + model_config, + hf_config=None, + convert_qkv_gate_up_by_simple_split=False, +): """ name: name of the parameter train_params: training parameters @@ -678,21 +681,33 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m """ from megatron.core import mpu + train_tp_size = mpu.get_tensor_model_parallel_world_size() if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: # if the tensor is qkv, for each param on tp, split into q, k, v # concat q, k, v separately. q_lst = [] k_lst = [] v_lst = [] - assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 - num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" + num_attention_heads = model_config.num_attention_heads + num_key_value_heads = model_config.num_key_value_heads + if "vision_model" in name: + num_attention_heads = hf_config.vision_config.num_heads + num_key_value_heads = hf_config.vision_config.num_heads + assert num_attention_heads % num_key_value_heads == 0 + num_q_per_kv = num_attention_heads // num_key_value_heads + assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( + f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" + ) kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: - num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size() + num_query_groups_per_partition = num_key_value_heads // train_tp_size for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition] + split_size = [ + kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + ] q, k, v = chunk.split(split_size) q_lst.append(q) k_lst.append(k) @@ -702,7 +717,11 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m v = torch.cat(v_lst, dim=0) infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] - elif layer_name_mapping.get("gate_proj_layer_name") in name: + elif ( + layer_name_mapping.get("gate_proj_layer_name") in name + and "layer_norm" not in name + and "vision_model.projection" not in name + ): # if the tensor is gate and proj gate_lst = [] up_lst = [] @@ -724,7 +743,14 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m return infer_params -def per_tensor_generator(actor_module, model_config, weight_converter, transformer_config, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True): +def per_tensor_generator( + actor_module, + model_config, + weight_converter, + transformer_config, + layer_name_mapping, + convert_qkv_gate_up_by_simple_split=True, +): from megatron.core import parallel_state as mpu pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -745,11 +771,11 @@ def tensor_generator(): yield name, param # note # there is a bug in megatron GPTModel - # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in state_dict(). - # for now we patch it by adding those keys to extra_keys. + # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in + # state_dict(). for now we patch it by adding those keys to extra_keys. extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] for name in extra_keys: - yield name, model.state_dict()[name].to(torch.cuda.current_device()) + yield name, model.state_dict()[name].to(get_device_id()) # we need first make all rank get full model information meta_info = [] @@ -764,7 +790,9 @@ def tensor_generator(): meta_info.append((pp_rank, scan_vpp_idx, idx, name)) obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.all_gather_object( + object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() + ) layer_list_meta = [item for sublist in obj_spec_output for item in sublist] gen_func = tensor_generator() @@ -774,7 +802,9 @@ def tensor_generator(): if model_config.tie_word_embeddings and ("output_layers" in name): import warnings - warnings.warn("Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2) + warnings.warn( + "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2 + ) continue if cur_pp_rank == pp_rank: @@ -806,7 +836,7 @@ def tensor_generator(): global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)] global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] - for name, param in zip(global_expert_names, infer_params): + for name, param in zip(global_expert_names, infer_params, strict=True): if etp_size > 1: # gather etp etp_params = [torch.empty_like(param) for _ in range(etp_size)] @@ -815,12 +845,20 @@ def tensor_generator(): else: params = [param] - merge_params = default_tp_concat_fn(layer_name_mapping, name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split) + merge_params = default_tp_concat_fn( + layer_name_mapping, + name, + broad_pp_tensor, + params, + model_config, + weight_converter.hf_config, + convert_qkv_gate_up_by_simple_split, + ) if not isinstance(merge_params, list): merge_params = [merge_params] converted_names, converted_params = weight_converter.convert_param(name, merge_params) - yield from zip(converted_names, converted_params) + yield from zip(converted_names, converted_params, strict=True) continue # tp all gather @@ -831,7 +869,15 @@ def tensor_generator(): else: infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) - infer_params = default_tp_concat_fn(layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, convert_qkv_gate_up_by_simple_split) + infer_params = default_tp_concat_fn( + layer_name_mapping, + cur_name, + broad_pp_tensor, + infer_params, + model_config, + weight_converter.hf_config, + convert_qkv_gate_up_by_simple_split, + ) else: infer_params = broad_pp_tensor @@ -839,7 +885,7 @@ def tensor_generator(): infer_params = [infer_params] converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) - yield from zip(converted_names, converted_params) + yield from zip(converted_names, converted_params, strict=True) def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConfig): @@ -853,7 +899,10 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf Extension to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset""" ''' if config.pipeline_model_parallel_size > 1: - if config.num_layers_in_first_pipeline_stage is not None or config.num_layers_in_last_pipeline_stage is not None: + if ( + config.num_layers_in_first_pipeline_stage is not None + or config.num_layers_in_last_pipeline_stage is not None + ): # Calculate number of pipeline stages to distribute the remaining Transformer # layers after deducting the Transformer layers in the first or the last stages middle_pipeline_stages = config.pipeline_model_parallel_size @@ -871,10 +920,16 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage # are not set, we will not enable uneven pipeline. All layers will be treated # as middle layers. - num_layers_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage + num_layers_in_first_pipeline_stage = ( + 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage + ) + num_layers_in_last_pipeline_stage = ( + 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage + ) - middle_num_layers = config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage + middle_num_layers = ( + config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage + ) if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() @@ -883,27 +938,46 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf # If the num_layers_in_first_pipeline_stage and # num_layers_in_last_pipeline_stage are not set, all pipeline stages # will be treated as middle pipeline stages in the calculation - num_layers_per_virtual_model_chunk_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage // vp_size - - num_layers_per_virtual_model_chunk_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage // vp_size + num_layers_per_virtual_model_chunk_in_first_pipeline_stage = ( + 0 + if config.num_layers_in_first_pipeline_stage is None + else config.num_layers_in_first_pipeline_stage // vp_size + ) + + num_layers_per_virtual_model_chunk_in_last_pipeline_stage = ( + 0 + if config.num_layers_in_last_pipeline_stage is None + else config.num_layers_in_last_pipeline_stage // vp_size + ) num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size # First stage + middle stage + last stage - total_virtual_chunks = num_layers_per_virtual_model_chunk_in_first_pipeline_stage + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage + num_layers_per_virtual_model_chunk_in_last_pipeline_stage + total_virtual_chunks = ( + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage + + num_layers_per_virtual_model_chunk_in_last_pipeline_stage + ) # Calculate the layer offset with interleaved uneven pipeline parallelism if pipeline_rank == 0: offset = vp_rank * total_virtual_chunks else: - offset = vp_rank * total_virtual_chunks + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + (pipeline_rank - 1) * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) + offset = ( + vp_rank * total_virtual_chunks + + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + + (pipeline_rank - 1) + * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) + ) else: if middle_pipeline_stages > 0: num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages else: num_layers_per_pipeline_rank = 0 - middle_pipeline_rank = pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 + middle_pipeline_rank = ( + pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 + ) if pipeline_rank == 0: offset = 0 diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py index 9396b41a6..9386f0d88 100644 --- a/verl/utils/memory_buffer.py +++ b/verl/utils/memory_buffer.py @@ -15,11 +15,13 @@ This file contains utilities to manipulate torch memory buffers """ -from typing import Dict, List, Optional +from typing import Optional import torch from torch import nn +from verl.utils.device import get_device_name + class MemoryBuffer: """ @@ -34,7 +36,7 @@ def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Op if source is not None: self.data = source else: - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device="cuda", requires_grad=False) + self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False) def zero(self): """Reset the buffer to zero.""" @@ -57,7 +59,7 @@ def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): return (numel + align_numel - 1) // align_numel * align_numel -def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]: +def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]: """ Return a dictionary containing name to a shape and dtype. """ @@ -67,7 +69,7 @@ def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]: return weight_buffer_meta -def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]: +def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]: """Build the memory buffer given weight_buffer_meta Args: @@ -96,7 +98,9 @@ def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype return memory_buffers -def build_memory_reference_from_module(module: torch.nn.Module, memory_buffers: Dict[torch.dtype, MemoryBuffer], maintain_weight=True): +def build_memory_reference_from_module( + module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True +): start_index = {} for dtype in memory_buffers: start_index[dtype] = 0 @@ -110,7 +114,7 @@ def build_memory_reference_from_module(module: torch.nn.Module, memory_buffers: param.data = buffer -def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]): +def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]): """Build the memory references. The memory buffers are built using the build_memory_buffer API. This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta. @@ -178,7 +182,7 @@ def __init__(self, transform_memory_param_fn): self._named_parameters = {} self.transform_memory_param_fn = transform_memory_param_fn - def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]): + def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]): """ Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct a large buffer for each dtype in the weight_buffer. diff --git a/verl/utils/metric/utils.py b/verl/utils/metric/utils.py index f3281b3c4..f9e7cd511 100644 --- a/verl/utils/metric/utils.py +++ b/verl/utils/metric/utils.py @@ -15,12 +15,12 @@ Metrics utils. """ -from typing import Any, Dict, List +from typing import Any import numpy as np -def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: +def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: """ Reduces a dictionary of metric lists by computing the mean, max, or min of each list. The reduce operation is determined by the key name: diff --git a/verl/utils/model.py b/verl/utils/model.py index 3074ea0c0..04cc34fe5 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -16,8 +16,10 @@ """ import os +import re import warnings -from typing import Dict, Optional, Type +from dataclasses import dataclass +from typing import Optional import numpy as np import torch @@ -28,9 +30,12 @@ GenerationConfig, MistralForSequenceClassification, PretrainedConfig, + PreTrainedModel, ) +from transformers.modeling_outputs import CausalLMOutputWithPast from verl.models.registry import ModelRegistry +from verl.utils.import_utils import is_trl_available class LambdaLayer(nn.Module): @@ -59,10 +64,12 @@ def update_model_config(module_config, override_config_kwargs): setattr(module_config, key, val) -def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> Dict: +def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict: if override_config_kwargs is None: override_config_kwargs = {} - assert isinstance(override_config_kwargs, Dict), f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + assert isinstance(override_config_kwargs, dict), ( + f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + ) module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) update_model_config(module_config, override_config_kwargs) @@ -100,8 +107,12 @@ def create_huggingface_actor(model_name: str, override_config_kwargs=None, autom override_config_kwargs = {} if automodel_kwargs is None: automodel_kwargs = {} - assert isinstance(override_config_kwargs, Dict), f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" - module_config = get_huggingface_actor_config(model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False)) + assert isinstance(override_config_kwargs, dict), ( + f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + ) + module_config = get_huggingface_actor_config( + model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False) + ) module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs) return module @@ -116,11 +127,15 @@ def create_huggingface_critic(model_name: str, override_config_kwargs=None, auto Returns: """ - critic_module: nn.Module = create_huggingface_actor(model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs) + critic_module: nn.Module = create_huggingface_actor( + model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs + ) if automodel_kwargs is None: automodel_kwargs = {} torch_dtype = automodel_kwargs.get("torch_dtype", torch.float32) - critic_module.lm_head = nn.Sequential(nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze)) + critic_module.lm_head = nn.Sequential( + nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze) + ) return critic_module @@ -205,6 +220,101 @@ def compute_position_id_with_mask(mask): return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) +def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel): + # convert state dict keys: https://github.com/huggingface/transformers/pull/38385 + if not hasattr(model, "_checkpoint_conversion_mapping"): + return state_dict + + reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()} + original_weights = {} + for key, value in state_dict.items(): + for pattern, replacement in reverse_key_mapping.items(): + replacement = replacement.lstrip("^") # strip off un-needed chars and patterns + replacement = re.sub(r"\(.*\)", "", replacement) + key, n_replace = re.subn(pattern, replacement, key) + # Early exit of the loop + if n_replace > 0: + break + + original_weights[key] = value + + return original_weights + + +def check_exclude_modules(config, key: str) -> bool: + """ + A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config. + Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py + + Args: + config (`LoraConfig` | `LycorisConfig`): A config to match exclude modules from + key (`str`): A key to search any matches in config + + Returns: + True of match object if key matches any exclude modules from config, False if no match found + """ + if hasattr(config, "exclude_modules") and config.exclude_modules: + if isinstance(config.exclude_modules, str): + if re.fullmatch(config.exclude_modules, key): + return True + elif key in config.exclude_modules: + return True + elif any(key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules): + return True + return False + + +def check_target_modules(config, key: str) -> bool: + """ + A helper method to check if the passed module's key name matches any of the target modules in the adapter_config. + Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py + + Args: + config (`LoraConfig` | `LycorisConfig`): A config to match target modules from + key (`str`): A key to search any matches in config + + Returns: + True of match object if key matches any target modules from config, False if no match found + """ + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, key) + elif key in config.target_modules: + # this module is specified directly in target_modules + target_module_found = True + else: + target_module_found = any(key.endswith(f".{target_key}") for target_key in config.target_modules) + + layer_indexes = getattr(config, "layers_to_transform", None) + layers_pattern = getattr(config, "layers_pattern", None) + + is_using_layer_indexes = layer_indexes is not None and ( + len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True + ) + if is_using_layer_indexes and target_module_found: + layer_index = None + # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave + # For now, empty layers_pattern means any layer pattern is ok + if layers_pattern is None or len(layers_pattern) == 0: + layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + else: + layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + for pattern in layers_pattern: + layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) + if layer_index is not None: + break + + if layer_index is None: + target_module_found = False + else: + layer_index = int(layer_index.group(1)) + if isinstance(layer_indexes, int): + target_module_found = layer_index == layer_indexes + else: + target_module_found = layer_index in layer_indexes + + return target_module_found + + def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"): """ Transform the model name in each model_chunk in each pp stage into the name in inference engine @@ -244,11 +354,15 @@ def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): vpp_size = len(params[pp_rank]) for vpp_rank in range(vpp_size): for name, param in params[pp_rank][vpp_rank].items(): - normalized_name = normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name) + normalized_name = normalize_model_name( + name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name + ) yield normalized_name, param -def get_parallel_model_from_config(config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False): +def get_parallel_model_from_config( + config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): from megatron.core import ModelParallelConfig assert isinstance(megatron_config, ModelParallelConfig) @@ -264,14 +378,17 @@ def get_parallel_model_from_config(config, megatron_config, pre_process=None, po return model -def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> Type[nn.Module]: +def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: model_cls = ModelRegistry.load_model_cls(arch, value) print("after load model cls") if model_cls is not None: return model_cls - raise ValueError(f"Model architectures {architectures} are not supported for now. Supported architectures: {ModelRegistry.get_supported_archs()}") + raise ValueError( + f"Model architectures {architectures} are not supported for now. Supported architectures: " + f"{ModelRegistry.get_supported_archs()}" + ) def _load_hf_model(config, model_config, is_value_model, local_cache_path): @@ -289,7 +406,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): from verl.utils.fs import copy_to_local print(f"start download from {config.model.path}") - local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get('use_shm', False)) + local_model_path = copy_to_local( + src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) + ) print("finish download") else: local_model_path = config.model.path @@ -310,7 +429,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): ) # use score head instead of lm_head state_dict = model.state_dict() state_dict["lm_head.weight"] = state_dict["score.weight"] - state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][:32000] # workaround, 32001 -> 32000 + state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ + :32000 + ] # workaround, 32001 -> 32000 is_value_model = True else: model = AutoModelForCausalLM.from_pretrained( @@ -324,9 +445,26 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): return architectures, model, state_dict, is_value_model -def load_megatron_model_weights(config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"): +def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"): + local_cache_path = os.path.expanduser(local_cache_path) + if config.model.path.startswith("hdfs:"): + from verl.utils.fs import copy_to_local + + local_model_path = copy_to_local( + src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) + ) + else: + local_model_path = config.model.path + return local_model_path + + +def load_megatron_model_weights( + config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" +): """Load weights for verl customized model.""" - architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path) + architectures, model, state_dict, is_value_model = _load_hf_model( + config, model_config, is_value_model, local_cache_path + ) from verl.models.weight_loader_registry import get_weight_loader @@ -345,7 +483,9 @@ def load_megatron_model_weights(config, model_config, parallel_model, params_dty return model.config -def load_megatron_gptmodel_weights(config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"): +def load_megatron_gptmodel_weights( + config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" +): """Load weights for mcore GPT model.""" _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path) @@ -398,15 +538,13 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False): from megatron.core import dist_checkpointing from megatron.core.dist_checkpointing.serialization import StrictHandling - from megatron.core.models.gpt.gpt_model import GPTModel + + from verl.utils.megatron_utils import unwrap_model # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED strict = StrictHandling.ASSUME_OK_UNEXPECTED for model in parallel_model: - if isinstance(model.module, GPTModel): - ssd = model.module.sharded_state_dict() - else: - ssd = model.module.module.sharded_state_dict() + ssd = unwrap_model(model).sharded_state_dict() if is_value_model: for k in list(ssd.keys()): if "output_layer" in k: @@ -416,7 +554,9 @@ def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=Fal return -def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False): +def get_parallel_gptmodel_from_config( + tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.models.gpt.gpt_model import GPTModel @@ -444,5 +584,81 @@ def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, pos if post_process and value: from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig) + parallel_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) return parallel_model + + +def patch_valuehead_model(model) -> None: + from types import MethodType + + from transformers import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + + def can_generate(self): + return False + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + model._keys_to_ignore_on_save = ignore_modules + model.tie_weights = MethodType(tie_weights, model) + model.get_input_embeddings = MethodType(get_input_embeddings, model) + model.get_output_embeddings = MethodType(get_output_embeddings, model) + model.can_generate = MethodType(can_generate, model) + model._no_split_modules = getattr(model.pretrained_model, "_no_split_modules", []) + + +def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code): + from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq + + try: + model = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + return model + except BaseException as e: + if not is_trl_available(): + raise RuntimeError( + f"model({local_path}) is not a value head model, please install trl to make it valid" + ) from e + + assert is_trl_available() + + from trl import AutoModelForCausalLMWithValueHead + + if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): + module_class = AutoModelForVision2Seq + else: + module_class = AutoModelForCausalLM + ori_model = module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model) + patch_valuehead_model(model) + return model + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None diff --git a/verl/utils/profiler/__init__.py b/verl/utils/profiler/__init__.py new file mode 100644 index 000000000..2242c24fe --- /dev/null +++ b/verl/utils/profiler/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..device import is_npu_available +from ..import_utils import is_nvtx_available +from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from .profile import DistProfilerExtension, ProfilerConfig + +if is_nvtx_available(): + from .nvtx_profile import NsightSystemsProfiler as DistProfiler + from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer +elif is_npu_available: + from .mstx_profile import NPUProfiler as DistProfiler + from .mstx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer +else: + from .performance import marked_timer + from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range + +__all__ = [ + "GPUMemoryLogger", + "log_gpu_memory_usage", + "mark_start_range", + "mark_end_range", + "mark_annotate", + "DistProfiler", + "DistProfilerExtension", + "ProfilerConfig", + "simple_timer", + "marked_timer", +] diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py new file mode 100644 index 000000000..d4fb53650 --- /dev/null +++ b/verl/utils/profiler/config.py @@ -0,0 +1,61 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import ClassVar + +from verl.base_config import BaseConfig + + +@dataclass +class ProfilerConfig(BaseConfig): + """Worker profiler config. Currently only support Nsight system profiler. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + discrete (bool): True for each task has its own database, False for all tasks in one training step + share one database. + all_ranks (bool): Whether to profile all ranks. + ranks (list[int]): The ranks that will be profiled. Defaults to []. + """ + + # the fields expected to be frozen + _frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"} + + discrete: bool = False + + all_ranks: bool = False + + ranks: list[int] = field(default_factory=list) + + def union(self, other: "ProfilerConfig") -> "ProfilerConfig": + return ProfilerConfig( + all_ranks=self.all_ranks or other.all_ranks, + ranks=list(set(self.ranks or []) | set(other.ranks or [])), + discrete=self.discrete or other.discrete, + ) + + def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig": + return ProfilerConfig( + all_ranks=self.all_ranks and other.all_ranks, + ranks=list(set(self.ranks or []) & set(other.ranks or [])), + discrete=self.discrete and other.discrete, + ) + + def __post_init__(self) -> None: + """config validation logics go here""" + assert isinstance(self.ranks, set | list | tuple), ( + f"Profiler ranks must be of type list, got {type(self.ranks)}" + ) diff --git a/verl/utils/profiler/empty_annotations.py b/verl/utils/profiler/empty_annotations.py new file mode 100644 index 000000000..ed18dd359 --- /dev/null +++ b/verl/utils/profiler/empty_annotations.py @@ -0,0 +1,40 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + + +def mark_start_range( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> None: + pass + + +def mark_end_range(range_id: str) -> None: + pass + + +def mark_annotate( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> Callable: + def decorator(func): + return func + + return decorator diff --git a/verl/utils/profiler/mstx_profile.py b/verl/utils/profiler/mstx_profile.py new file mode 100644 index 000000000..c5c35cec0 --- /dev/null +++ b/verl/utils/profiler/mstx_profile.py @@ -0,0 +1,219 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Inspired from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py +import functools +import os +from contextlib import contextmanager +from typing import Callable, Optional + +import torch_npu +from omegaconf import DictConfig +from torch_npu.npu import mstx + +from .profile import DistProfiler, ProfilerConfig + + +def mark_start_range(message: Optional[str] = None) -> None: + """Start a mark range in the profiler. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + """ + return mstx.range_start(message=message) + + +def mark_end_range(range_id: str) -> None: + """End a mark range in the profiler. + + Args: + range_id (str): + The id of the mark range to end. + """ + return mstx.range_end(range_id) + + +def mark_annotate(message: Optional[str] = None) -> Callable: + """Decorate a function to annotate a mark range along with the function life cycle. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + """ + + def decorator(func): + profile_message = message or func.__name__ + return mstx.mstx_range(profile_message)(func) + + return decorator + + +@contextmanager +def marked_timer(name: str, timing_raw: dict[str, float], **kwargs): + """Context manager for timing with MSTX markers. + + This utility function measures the execution time of code within its context, + accumulates the timing information, and adds MSTX markers for profiling. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + mark_range = mark_start_range(message=name) + from .performance import _timer + + yield from _timer(name, timing_raw) + mark_end_range(mark_range) + + +def get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_step: Optional[str] = None): + """Generate and return an NPU profiler object. + + Args: + option (DictConfig): + The options to control npu profiler. + role (str, optional): + The role of the current data collection. Defaults to None. + profile_step(str, optional): + The current training step. Defaults to None. + """ + if option.level == "level_none": + profile_level = torch_npu.profiler.ProfilerLevel.Level_none + elif option.level == "level0": + profile_level = torch_npu.profiler.ProfilerLevel.Level0 + elif option.level == "level1": + profile_level = torch_npu.profiler.ProfilerLevel.Level1 + elif option.level == "level2": + profile_level = torch_npu.profiler.ProfilerLevel.Level2 + else: + raise ValueError(f"level only supports level0, 1, 2, and level_none, but gets {option.level}") + + profile_save_path = option.save_path + if profile_step: + profile_save_path = os.path.join(profile_save_path, profile_step) + if role: + profile_save_path = os.path.join(profile_save_path, role) + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=profile_level, + export_type=torch_npu.profiler.ExportType.Text, + data_simplification=True, + msprof_tx=True, + ) + + activites = [] + if option.with_npu: + activites.append(torch_npu.profiler.ProfilerActivity.NPU) + if option.with_cpu: + activites.append(torch_npu.profiler.ProfilerActivity.CPU) + + prof = torch_npu.profiler.profile( + with_modules=option.with_module, + with_stack=option.with_stack, + record_shapes=option.record_shapes, + profile_memory=option.with_memory, + activities=activites, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=option.analysis), + experimental_config=experimental_config, + ) + return prof + + +class NPUProfiler(DistProfiler): + """ + NPU profiler. Initialized in a worker to control the NPU profiler. + """ + + _define_count = 0 + + def __init__(self, rank: int, config: ProfilerConfig, **kwargs): + """Initialize the NsightSystemsProfiler. + + Args: + rank (int): The rank of the current process. + config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used. + """ + if not config: + config = ProfilerConfig(ranks=[]) + self.this_step: bool = False + self.discrete: bool = config.discrete + self.this_rank: bool = False + self.profile_npu = None + self.profile_option = kwargs.get("option", None) + if config.all_ranks: + self.this_rank = True + elif config.ranks: + self.this_rank = rank in config.ranks + + def start(self, **kwargs): + role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None) + profile_step = str(profile_step) if profile_step is not None else None + if self.this_rank and self.profile_option is not None: + self.this_step = True + if not self.discrete and NPUProfiler._define_count == 0: + self.profile_npu = get_npu_profiler(option=self.profile_option, role=role, profile_step=profile_step) + self.profile_npu.start() + NPUProfiler._define_count += 1 + + def stop(self): + if self.this_rank and self.profile_option is not None: + self.this_step = False + if not self.discrete and NPUProfiler._define_count == 1: + self.profile_npu.step() + self.profile_npu.stop() + NPUProfiler._define_count -= 1 + + @staticmethod + def annotate(message: Optional[str] = None, role: Optional[str] = None, **kwargs) -> Callable: + """Decorate a Worker member function to profile the current rank in the current training step. + + Requires the target function to be a member function of a Worker, + which has a member field `profiler` with NPUProfiler type. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + role (str, optional): + The role of the current data collection. Defaults to None. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + profile_name = message or func.__name__ + + if self.profiler.this_step and self.profile_option is not None: + if self.profiler.discrete: + profile_npu = get_npu_profiler(option=self.profile_option, role=role) + profile_npu.start() + mark_range = mark_start_range(message=profile_name) + + result = func(self, *args, **kwargs) + + if self.profiler.this_step and self.profile_option is not None: + mark_end_range(mark_range) + if self.profiler.discrete: + profile_npu.step() + profile_npu.stop() + + return result + + return wrapper + + return decorator diff --git a/verl/utils/profiler/nvtx_profile.py b/verl/utils/profiler/nvtx_profile.py new file mode 100644 index 000000000..9ebce374f --- /dev/null +++ b/verl/utils/profiler/nvtx_profile.py @@ -0,0 +1,191 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from contextlib import contextmanager +from typing import Callable, Optional + +import nvtx +import torch + +from .profile import DistProfiler, ProfilerConfig + + +def mark_start_range( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> None: + """Start a mark range in the profiler. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + color (str, optional): + The color of the range. Defaults to None. + domain (str, optional): + The domain of the range. Defaults to None. + category (str, optional): + The category of the range. Defaults to None. + """ + return nvtx.start_range(message=message, color=color, domain=domain, category=category) + + +def mark_end_range(range_id: str) -> None: + """End a mark range in the profiler. + + Args: + range_id (str): + The id of the mark range to end. + """ + return nvtx.end_range(range_id) + + +def mark_annotate( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> Callable: + """Decorate a function to annotate a mark range along with the function life cycle. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + color (str, optional): + The color of the range. Defaults to None. + domain (str, optional): + The domain of the range. Defaults to None. + category (str, optional): + The category of the range. Defaults to None. + """ + + def decorator(func): + profile_message = message or func.__name__ + return nvtx.annotate(profile_message, color=color, domain=domain, category=category)(func) + + return decorator + + +@contextmanager +def marked_timer( + name: str, + timing_raw: dict[str, float], + color: str = None, + domain: Optional[str] = None, + category: Optional[str] = None, +): + """Context manager for timing with NVTX markers. + + This utility function measures the execution time of code within its context, + accumulates the timing information, and adds NVTX markers for profiling. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + color (Optional[str]): Color for the NVTX marker. Defaults to None. + domain (Optional[str]): Domain for the NVTX marker. Defaults to None. + category (Optional[str]): Category for the NVTX marker. Defaults to None. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + mark_range = mark_start_range(message=name, color=color, domain=domain, category=category) + from .performance import _timer + + yield from _timer(name, timing_raw) + mark_end_range(mark_range) + + +class NsightSystemsProfiler(DistProfiler): + """Nsight system profiler. Installed in a worker to control the Nsight system profiler.""" + + def __init__(self, rank: int, config: Optional[ProfilerConfig], **kwargs): + """Initialize the NsightSystemsProfiler. + + Args: + rank (int): The rank of the current process. + config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used. + """ + # If no configuration is provided, create a default ProfilerConfig with an empty list of ranks + if not config: + config = ProfilerConfig(ranks=[]) + self.this_step: bool = False + self.discrete: bool = config.discrete + self.this_rank: bool = False + if config.all_ranks: + self.this_rank = True + elif config.ranks: + self.this_rank = rank in config.ranks + + def start(self, **kwargs): + if self.this_rank: + self.this_step = True + if not self.discrete: + torch.cuda.profiler.start() + + def stop(self): + if self.this_rank: + self.this_step = False + if not self.discrete: + torch.cuda.profiler.stop() + + @staticmethod + def annotate( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, + **kwargs, + ) -> Callable: + """Decorate a Worker member function to profile the current rank in the current training step. + + Requires the target function to be a member function of a Worker, which has a member field `profiler` with + NightSystemsProfiler type. + + Args: + message (str, optional): + The message to be displayed in the profiler. Defaults to None. + color (str, optional): + The color of the range. Defaults to None. + domain (str, optional): + The domain of the range. Defaults to None. + category (str, optional): + The category of the range. Defaults to None. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + profile_name = message or func.__name__ + + if self.profiler.this_step: + if self.profiler.discrete: + torch.cuda.profiler.start() + mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category) + + result = func(self, *args, **kwargs) + + if self.profiler.this_step: + mark_end_range(mark_range) + if self.profiler.discrete: + torch.cuda.profiler.stop() + + return result + + return wrapper + + return decorator diff --git a/verl/utils/profiler/performance.py b/verl/utils/profiler/performance.py new file mode 100644 index 000000000..8991896a2 --- /dev/null +++ b/verl/utils/profiler/performance.py @@ -0,0 +1,205 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import inspect +import logging +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.distributed as dist +from codetiming import Timer + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import DecoratorLoggerBase + + +def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> tuple[str]: + """Get current memory usage.""" + assert unit in ["GB", "MB", "KB"] + divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 + mem_allocated = get_torch_device().memory_allocated() + mem_reserved = get_torch_device().memory_reserved() + # use get_torch_device().mem_get_info to profile device memory + # since vllm's sleep mode works below pytorch + # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 + mem_free, mem_total = get_torch_device().mem_get_info() + mem_used = mem_total - mem_free + mem_allocated = f"{mem_allocated / divisor:.{precision}f}" + mem_reserved = f"{mem_reserved / divisor:.{precision}f}" + mem_used = f"{mem_used / divisor:.{precision}f}" + mem_total = f"{mem_total / divisor:.{precision}f}" + return mem_allocated, mem_reserved, mem_used, mem_total + + +def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): + """Log GPU memory usage information. + + Args: + head (str): A descriptive header for the memory usage log message. + logger (logging.Logger, optional): Logger instance to use for logging. If None, prints to stdout. + level: Logging level to use. Defaults to logging.DEBUG. + rank (int): The rank of the process to log memory for. Defaults to 0. + """ + if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + + if logger is None: + print(message) + else: + logger.log(msg=message, level=level) + + +class GPUMemoryLogger(DecoratorLoggerBase): + """A decorator class to log GPU memory usage. + + Example: + >>> from verl.utils.profiler.performance import GPUMemoryLogger + >>> @GPUMemoryLogger(role="actor") + >>> def update_actor(self, batch): + ... # real actor update logics + ... return + """ + + def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): + if dist.is_initialized() and dist.get_world_size() > 1: + rank = dist.get_rank() + else: + rank = 0 + super().__init__(role, logger, level, rank, log_only_rank_0) + + def __call__(self, decorated_function: callable): + def f(*args, **kwargs): + return self.log(decorated_function, *args, **kwargs) + + return f + + def log(self, func, *args, **kwargs): + name = func.__name__ + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + self.logging_function(message) + + output = func(*args, **kwargs) + + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + + self.logging_function(message) + return output + + +def log_print(ctn: Any): + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + frame = inspect.currentframe().f_back + function_name = frame.f_code.co_name + line_number = frame.f_lineno + file_name = frame.f_code.co_filename.split("/")[-1] + print(f"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}") + + +def _timer(name: str, timing_raw: dict[str, float]): + """Inner function that handles the core timing logic. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + """ + with Timer(name=name, logger=None) as timer: + yield + if name not in timing_raw: + timing_raw[name] = 0 + timing_raw[name] += timer.last + + +@contextmanager +def simple_timer(name: str, timing_raw: dict[str, float]): + """Context manager for basic timing without NVTX markers. + + This utility function measures the execution time of code within its context + and accumulates the timing information in the provided dictionary. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + yield from _timer(name, timing_raw) + + +@contextmanager +def marked_timer( + name: str, + timing_raw: dict[str, float], + color: str = None, + domain: Optional[str] = None, + category: Optional[str] = None, +): + """Context manager for timing with platform markers. + + This utility function measures the execution time of code within its context, + accumulates the timing information, and adds platform markers for profiling. + This function is a default implementation when hardware profiler is not available. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + color (Optional[str]): Color for the marker. Defaults to None. + domain (Optional[str]): Domain for the marker. Defaults to None. + category (Optional[str]): Category for the marker. Defaults to None. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + yield from _timer(name, timing_raw) + + +def reduce_timing(timing_raw: dict[str, float]) -> dict[str, float]: + """Reduce timing information across all processes. + + This function uses distributed communication to gather and sum the timing + information from all processes in a distributed environment. + + Args: + timing_raw (Dict[str, float]): Dictionary containing timing information. + + Returns: + Dict[str, float]: Reduced timing information. + """ + if not dist.is_initialized(): + return timing_raw + + key_list, timing_list = [], [] + for key in sorted(timing_raw.keys()): + key_list.append(key) + timing_list.append(timing_raw[key]) + timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id()) + torch.distributed.all_reduce(timing_list, op=torch.distributed.ReduceOp.AVG) + timing_list = [tensor.item() for tensor in timing_list.to("cpu")] + timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))} + return timing_generate diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py new file mode 100644 index 000000000..4e7ce4fd3 --- /dev/null +++ b/verl/utils/profiler/profile.py @@ -0,0 +1,227 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Optional + +import torch +import torch.distributed + +from .config import ProfilerConfig + + +class Profiler: + """A PyTorch profiler wrapper class for collecting performance metrics. + + TODO(haibin.lin): this should implement the DistProfiler interface, and the config should be unified. + + This profiler provides a convenient interface for profiling PyTorch operations, + with support for: + + - CPU and CUDA activity profiling + - Configurable profiling schedule (wait/warmup/active steps) + - Multi-rank profiling support + - Chrome trace export + + Args: + config: Configuration object containing profiling parameters + """ + + def __init__(self, config): + # note : if we do not set use_profile, it will be set as None, so that all function will be skip + self.config = config + self.skip_prof = False + self.saved = False + self.prof = None + self.rank = torch.distributed.get_rank() + # we need to validate the config before using the profiler + self._validate() + if config.use_profile and self.rank in self.config.profile_ranks: + print(f"[Profiler] Profiler init for rank {self.rank}") + + self.prof = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=max(self.config.step_start - 1, 0), + warmup=1 if self.config.step_start > 0 else 0, + active=self.config.step_end - self.config.step_start, + repeat=1, + ), + record_shapes=True, + with_stack=True, + ) + + def _validate(self): + if self.config.use_profile: + if self.config.profile_ranks is None: + print("[WARNING] Profile ranks is not set, default to rank 0") + self.config.profile_ranks = [0] + assert self.config.step_start >= 0, "[ERROR] Profile step start must be greater than 0" + assert self.config.step_end >= 0, "[ERROR] Profile step end must be greater than 0" + assert self.config.step_start < self.config.step_end, ( + "[ERROR] Profile step start must be less than step end" + ) + + def check(self): + return self.prof is not None and not self.skip_prof + + def start(self): + if self.check(): + print(f"[Profiler] started for rank {self.rank}") + self.prof.start() + + def step(self): + if self.check(): + self.prof.step() + + def stop(self): + if self.check(): + print(f"[Profiler] stopped for rank {self.rank}") + self.prof.stop() + + def save(self): + if self.prof is not None and not self.saved: + if not os.path.exists(self.config.save_path): + os.makedirs(self.config.save_path) + save_file_name = f"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json" + print(f"[Profiler] Saving trace to {self.config.save_path + save_file_name}") + self.prof.export_chrome_trace(self.config.save_path + save_file_name) + self.skip_prof = True + self.saved = True + + def stop_and_save(self): + if self.check(): + self.stop() + self.save() + + def stop_trace(self): + if self.check(): + print(f"[Profiler] Trace stopped for rank {self.rank}") + self.skip_prof = True + + +def mark_start_range( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> None: + """Start a profiling range marker (no-op implementation). + + Args: + message (Optional[str]): Message to associate with the range marker. + color (Optional[str]): Color for the marker visualization. + domain (Optional[str]): Domain for the marker. + category (Optional[str]): Category for the marker. + """ + pass + + +def mark_end_range(range_id: str) -> None: + """End a profiling range marker (no-op implementation). + + Args: + range_id (str): Identifier of the range to end. + """ + pass + + +def mark_annotate( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, +) -> Callable: + """Decorator to annotate a function with profiling markers (no-op implementation). + + Args: + message (Optional[str]): Message to associate with the annotation. + color (Optional[str]): Color for the marker visualization. + domain (Optional[str]): Domain for the marker. + category (Optional[str]): Category for the marker. + + Returns: + Callable: Decorator function that returns the original function unchanged. + """ + + def decorator(func): + return func + + return decorator + + +class DistProfiler: + """A distributed profiler class for collecting performance metrics across multiple ranks. + + This profiler is designed to work in distributed training environments, allowing selective + profiling of specific ranks or all ranks. It provides basic start/stop functionality and + supports annotation of code sections for detailed profiling. + + Args: + rank (int): The rank of the current process + config (ProfilerConfig, optional): Configuration for the profiler. + """ + + def __init__(self, rank: int, config: Optional[ProfilerConfig] = None, **kwargs): + pass + + def start(self, **kwargs): + pass + + def stop(self): + pass + + @staticmethod + def annotate( + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, + **kwargs, + ) -> Callable: + def decorator(func): + return func + + return decorator + + +class DistProfilerExtension: + """An extension class for DistProfiler that provides distributed profiling capabilities. + + It is intended for workers in verl that single controller invokes. + + This class wraps a DistProfiler instance and provides methods to start/stop profiling + that can be dispatched across multiple ranks in a distributed training environment. + + Args: + profiler (DistProfiler): The base distributed profiler instance to extend + """ + + def __init__(self, profiler: DistProfiler): + self.profiler = profiler + + from verl.single_controller.base.decorator import Dispatch, register + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 9f68a9c08..affe4ed9a 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -20,14 +20,15 @@ import os import queue # Import the queue module for exception type hint import signal +from contextlib import contextmanager from functools import wraps from types import SimpleNamespace -from typing import Any, Callable, Dict, Iterator, Optional, Tuple +from typing import Any, Callable, Iterator, Optional # --- Top-level helper for multiprocessing timeout --- # This function MUST be defined at the top level to be pickleable -def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: Tuple, kwargs: Dict[str, Any]): +def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): """ Internal wrapper function executed in the child process. Calls the original target function and puts the result or exception into the queue. @@ -124,11 +125,16 @@ def wrapper_mp(*args, **kwargs): except queue.Empty as err: exitcode = process.exitcode if exitcode is not None and exitcode != 0: - raise RuntimeError(f"Child process exited with error (exitcode: {exitcode}) before returning result.") from err + raise RuntimeError( + f"Child process exited with error (exitcode: {exitcode}) before returning result." + ) from err else: # Should have timed out if queue is empty after join unless process died unexpectedly # Update function name in error message if needed (optional but good practice) - raise TimeoutError(f"Operation timed out or process finished unexpectedly without result (exitcode: {exitcode}).") from err + raise TimeoutError( + f"Operation timed out or process finished unexpectedly without result " + f"(exitcode: {exitcode})." + ) from err finally: q.close() q.join_thread() @@ -138,7 +144,7 @@ def wrapper_mp(*args, **kwargs): return decorator -def union_two_dict(dict1: Dict, dict2: Dict): +def union_two_dict(dict1: dict, dict2: dict): """Union two dict. Will throw an error if there is an item not the same object with the same key. Args: @@ -156,7 +162,7 @@ def union_two_dict(dict1: Dict, dict2: Dict): return dict1 -def append_to_dict(data: Dict, new_data: Dict): +def append_to_dict(data: dict, new_data: dict): """Append values from new_data to lists in data. For each key in new_data, this function appends the corresponding value to a list @@ -225,7 +231,7 @@ def values(cls): class DynamicEnum(metaclass=DynamicEnumMeta): - _registry: Dict[str, "DynamicEnum"] = {} + _registry: dict[str, "DynamicEnum"] = {} _next_value: int = 0 def __init__(self, name: str, value: int): @@ -266,13 +272,46 @@ def remove(cls, name: str): def from_name(cls, name: str) -> Optional["DynamicEnum"]: return cls._registry.get(name.upper()) + +@contextmanager +def temp_env_var(key: str, value: str): + """Context manager for temporarily setting an environment variable. + + This context manager ensures that environment variables are properly set and restored, + even if an exception occurs during the execution of the code block. + + Args: + key: Environment variable name to set + value: Value to set the environment variable to + + Yields: + None + + Example: + >>> with temp_env_var("MY_VAR", "test_value"): + ... # MY_VAR is set to "test_value" + ... do_something() + ... # MY_VAR is restored to its original value or removed if it didn't exist + """ + original = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original + + def convert_to_regular_types(obj): """Convert Hydra configs and other special types to regular Python types.""" from omegaconf import DictConfig, ListConfig - if isinstance(obj, (ListConfig, DictConfig)): + + if isinstance(obj, ListConfig | DictConfig): return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, list | tuple): return [convert_to_regular_types(x) for x in obj] elif isinstance(obj, dict): return {k: convert_to_regular_types(v) for k, v in obj.items()} - return obj \ No newline at end of file + return obj diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index 10a875afc..a738c0f3d 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -17,7 +17,7 @@ import concurrent.futures import os -from typing import Any, List, Optional +from typing import Any, Optional import ray @@ -45,7 +45,7 @@ def ray_noset_visible_devices(env_vars=os.environ): return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) -def parallel_put(data_list: List[Any], max_workers: Optional[int] = None): +def parallel_put(data_list: list[Any], max_workers: Optional[int] = None): """ Puts a list of data into the Ray object store in parallel using a thread pool. diff --git a/verl/utils/rendezvous/ray_backend.py b/verl/utils/rendezvous/ray_backend.py index 123f73463..d9911815d 100644 --- a/verl/utils/rendezvous/ray_backend.py +++ b/verl/utils/rendezvous/ray_backend.py @@ -42,7 +42,9 @@ def get_nccl_id_store_by_name(name): return None -def create_nccl_communicator_in_ray(rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5): +def create_nccl_communicator_in_ray( + rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5 +): if rank == 0: nccl_id = get_unique_id() nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index a78154308..4cb31dead 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -16,7 +16,15 @@ from verl.utils.import_utils import deprecated -def default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): +def default_compute_score( + data_source, + solution_str, + ground_truth, + extra_info=None, + sandbox_fusion_url=None, + concurrent_semaphore=None, + memory_limit_mb=None, +): """Compute the score for a given solution based on the data source. Args: @@ -32,7 +40,10 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No Raises: NotImplementedError: If the reward function is not implemented for the given data source. """ - reward_metric = extra_info.get("reward_metric", None) + # Handle extra_info format robustly + reward_metric = None + if extra_info and isinstance(extra_info, dict): + reward_metric = extra_info.get("reward_metric", None) # math if data_source.startswith("math"): @@ -90,6 +101,9 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No elif data_source.startswith('stem_web'): from . import stem_llm_judge res = stem_llm_judge.compute_score(data_source=data_source, model_output=solution_str, ground_truth=ground_truth, extra_info=extra_info) + elif data_source in ["reasoning_gym"]: + from . import reasoning_gym + res = reasoning_gym.compute_score(solution_str, ground_truth, extra_info=extra_info) elif data_source in ["ood__ifeval"]: from . import ifeval res = ifeval.compute_score(solution_str, ground_truth, extra_info=extra_info) @@ -99,10 +113,16 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No elif data_source in ["ood__ifbench"]: from . import ifbench res = ifbench.compute_score(solution_str, ground_truth, extra_info=extra_info) + elif data_source in ["deepmath", "DeepMath", "zwhe99/DeepMath-103K"]: + from . import deepmath + res = deepmath.compute_score(solution_str, ground_truth, extra_info=extra_info) + elif data_source in ["stem_nemotron", "nemotron_stem"]: + from . import nemotron_stem + res = nemotron_stem.compute_score(solution_str, ground_truth, extra_info=extra_info) + # NOTE: above is added by Reasoning360 elif data_source == "openai/gsm8k": from . import gsm8k - res = gsm8k.compute_score(solution_str, ground_truth) elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: from . import math @@ -136,7 +156,9 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No from . import sandbox_fusion # Pass the URL directly, ground_truth likely contains test cases here - res = sandbox_fusion.compute_score(sandbox_fusion_url, concurrent_semaphore, solution_str, ground_truth, continuous=True) + res = sandbox_fusion.compute_score( + sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True + ) else: # If no sandbox URL is provided, fall back to prime_code or raise error from . import prime_code @@ -147,27 +169,59 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No from . import geo3k res = geo3k.compute_score(solution_str, ground_truth) - elif data_source in ["searchR1_nq", "searchR1_triviaqa", "searchR1_popqa", "searchR1_hotpotqa", "searchR1_2wikimultihopqa", "searchR1_musique", "searchR1_bamboogle"]: + elif data_source in [ + "searchR1_nq", + "searchR1_triviaqa", + "searchR1_popqa", + "searchR1_hotpotqa", + "searchR1_2wikimultihopqa", + "searchR1_musique", + "searchR1_bamboogle", + ]: from . import search_r1_like_qa_em res = search_r1_like_qa_em.compute_score(solution_str, ground_truth) + + elif data_source.startswith("synlogic"): + from .synlogic.synlogic import verifier_classes + from .synlogic.data import Data + + form_solution = solution_str.strip().split('
')[-1].strip() + data = Data.from_json_str(extra_info["game_data_str"]) + verifier = verifier_classes[data_source.replace("synlogic_", "")]() + res = verifier.verify(data, form_solution) + if res: + res = 1.0 + else: + res = 0.0 + else: raise NotImplementedError(f"Reward function is not implemented for {data_source=}") if isinstance(res, dict): return res - elif isinstance(res, (int, float, bool)): + elif isinstance(res, int | float | bool): return float(res) else: return float(res[0]) @deprecated("verl.utils.reward_score.default_compute_score") -def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None): +def _default_compute_score( + data_source, + solution_str, + ground_truth, + extra_info=None, + sandbox_fusion_url=None, + concurrent_semaphore=None, + memory_limit_mb=None, +): """ Legacy function API to be deprecated. Please use `default_compute_score` instead. """ - return default_compute_score(data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore) + return default_compute_score( + data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb + ) __all__ = ["default_compute_score"] diff --git a/verl/utils/reward_score/deepmath.py b/verl/utils/reward_score/deepmath.py new file mode 100644 index 000000000..3b6a5f32a --- /dev/null +++ b/verl/utils/reward_score/deepmath.py @@ -0,0 +1,225 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Union + + +def compute_score(solution_str: str, ground_truth: str, extra_info: Optional[dict] = None) -> float: + """Compute the score for DeepMath dataset solutions. + + Args: + solution_str: The model's solution/answer + ground_truth: The correct answer from the dataset + extra_info: Optional additional information (e.g., difficulty, topic) + + Returns: + float: 1.0 if correct, 0.0 otherwise + """ + try: + # Extract answer from solution if it's in boxed format + extracted_answer = extract_boxed_answer(solution_str) + if extracted_answer is None: + # Try to extract from common answer patterns + extracted_answer = extract_answer_patterns(solution_str) + + if extracted_answer is None: + # Use the full solution string as last resort + extracted_answer = solution_str.strip() + + # Normalize both answers for comparison + normalized_solution = normalize_math_answer(extracted_answer) + normalized_ground_truth = normalize_math_answer(ground_truth) + + # Check if answers are equivalent + if is_equivalent(normalized_solution, normalized_ground_truth): + return 1.0 + + # Additional check for numerical equivalence + if is_numerically_equivalent(normalized_solution, normalized_ground_truth): + return 1.0 + + return 0.0 + except Exception as e: + print(f"Error in DeepMath scoring: {e}") + return 0.0 + + +def extract_boxed_answer(text: str) -> Optional[str]: + """Extract answer from \\boxed{...} format.""" + # Look for the last boxed expression + pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}" + matches = re.findall(pattern, text) + if matches: + return matches[-1] + + # Also check for \boxed without braces + pattern2 = r"\\boxed\s+([^\s]+)" + matches2 = re.findall(pattern2, text) + if matches2: + return matches2[-1] + + return None + + +def extract_answer_patterns(text: str) -> Optional[str]: + """Extract answer from common answer patterns.""" + patterns = [ + r"(?:final answer|answer)[\s:]*(?:is)?[\s:]*([^\n.]+)", + r"(?:evaluates to|equals to|is equal to)[\s:]*([^\n.]+)", + r"therefore[\s,]+([^\n.]+)", + r"thus[\s,]+([^\n.]+)", + r"hence[\s,]+([^\n.]+)", + r"=\s*([^\n]+)$", # Last equals sign + r"(?:limit|integral|sum|product)[\s\w]*(?:evaluates to|is|equals)[\s:]*([^\n.]+)", + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + # Clean the extracted answer + answer = matches[-1].strip() + # Remove trailing punctuation but keep mathematical symbols + answer = answer.rstrip('.,;:') + return answer + + # Try to find any number at the end of the text + number_pattern = r"(?:^|\s)([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?|\d+/\d+)(?:\s*$|\s*[.,;]?\s*$)" + matches = re.findall(number_pattern, text) + if matches: + return matches[-1].strip() + + return None + + +def normalize_math_answer(answer: str) -> str: + """Normalize mathematical expressions for comparison.""" + # Remove whitespace + answer = answer.strip() + answer = re.sub(r'\s+', '', answer) + + # Remove dollar signs + answer = answer.replace('$', '') + + # Normalize LaTeX commands + answer = answer.replace('\\left', '') + answer = answer.replace('\\right', '') + answer = answer.replace('\\Big', '') + answer = answer.replace('\\big', '') + answer = answer.replace('\\cdot', '*') + answer = answer.replace('\\times', '*') + answer = answer.replace('\\div', '/') + + # Handle fractions + answer = normalize_fractions(answer) + + # Remove trailing punctuation + answer = answer.rstrip('.,;:') + + return answer + + +def normalize_fractions(text: str) -> str: + """Normalize fraction representations.""" + # Convert \frac{a}{b} to a/b for simple cases + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + + def frac_replacer(match): + num, den = match.groups() + # For simple numeric fractions, compute the value + try: + num_val = float(eval(num)) + den_val = float(eval(den)) + if den_val != 0: + result = num_val / den_val + # Return as integer if it's a whole number + if result == int(result): + return str(int(result)) + return str(result) + except: + pass + return f"({num})/({den})" + + text = re.sub(frac_pattern, frac_replacer, text) + + # Also handle tfrac and dfrac + text = text.replace('\\tfrac', '\\frac') + text = text.replace('\\dfrac', '\\frac') + + return text + + +def is_equivalent(answer1: str, answer2: str) -> bool: + """Check if two normalized answers are equivalent.""" + # Direct string comparison + if answer1 == answer2: + return True + + # Case-insensitive comparison for text answers + if answer1.lower() == answer2.lower(): + return True + + # Check common mathematical equivalences + equivalences = [ + ('infinity', '\\infty'), + ('inf', '\\infty'), + ('undefined', 'dne'), + ('doesnotexist', 'dne'), + ('none', 'dne'), + ] + + a1_lower = answer1.lower() + a2_lower = answer2.lower() + + for eq1, eq2 in equivalences: + if (eq1 in a1_lower and eq2 in a2_lower) or (eq2 in a1_lower and eq1 in a2_lower): + return True + + return False + + +def is_numerically_equivalent(answer1: str, answer2: str, tolerance: float = 1e-9) -> bool: + """Check if two answers are numerically equivalent.""" + try: + # Try to evaluate as numerical expressions + val1 = evaluate_expression(answer1) + val2 = evaluate_expression(answer2) + + if val1 is not None and val2 is not None: + return abs(val1 - val2) < tolerance + except: + pass + + return False + + +def evaluate_expression(expr: str) -> Optional[float]: + """Safely evaluate a mathematical expression.""" + try: + # Remove common LaTeX commands that might remain + expr = expr.replace('\\pi', '3.141592653589793') + expr = expr.replace('\\e', '2.718281828459045') + expr = expr.replace('^', '**') + + # Only allow safe operations + allowed_names = { + 'abs': abs, + 'min': min, + 'max': max, + } + + # Evaluate the expression safely + result = eval(expr, {"__builtins__": {}}, allowed_names) + return float(result) + except: + return None \ No newline at end of file diff --git a/verl/utils/reward_score/deepmath_test.py b/verl/utils/reward_score/deepmath_test.py new file mode 100644 index 000000000..7fde18db5 --- /dev/null +++ b/verl/utils/reward_score/deepmath_test.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +Test script for DeepMath integration +""" + +import sys +sys.path.append('/mnt/weka/home/jianshu.she/IFM/Reasoning360') + +from verl.utils.reward_score import default_compute_score + +def test_deepmath_scoring(): + """Test DeepMath scoring functionality""" + + print("Testing DeepMath scoring integration...") + + # Test cases + test_cases = [ + { + "solution": "\\boxed{0}", + "ground_truth": "0", + "expected": 1.0, + "description": "Exact match with boxed answer" + }, + { + "solution": "The limit evaluates to 0", + "ground_truth": "0", + "expected": 1.0, + "description": "Text extraction" + }, + { + "solution": "\\boxed{\\frac{2}{3}}", + "ground_truth": "2/3", + "expected": 1.0, + "description": "Fraction equivalence" + }, + { + "solution": "\\boxed{42}", + "ground_truth": "24", + "expected": 0.0, + "description": "Wrong answer" + }, + { + "solution": "The answer is \\infty", + "ground_truth": "infinity", + "expected": 1.0, + "description": "Infinity equivalence" + } + ] + + print("\nRunning test cases:") + print("=" * 60) + + all_passed = True + for i, test in enumerate(test_cases, 1): + try: + # Test with different data source identifiers + for data_source in ["deepmath", "DeepMath", "zwhe99/DeepMath-103K"]: + score = default_compute_score( + data_source=data_source, + solution_str=test["solution"], + ground_truth=test["ground_truth"] + ) + + passed = abs(score - test["expected"]) < 0.001 + + if not passed: + print(f"❌ Test {i} FAILED ({data_source}): {test['description']}") + print(f" Expected: {test['expected']}, Got: {score}") + all_passed = False + break + else: + print(f"✅ Test {i} PASSED: {test['description']}") + + except Exception as e: + print(f"❌ Test {i} ERROR: {test['description']}") + print(f" Error: {e}") + all_passed = False + + print("=" * 60) + if all_passed: + print("✅ All tests passed!") + else: + print("❌ Some tests failed") + + return all_passed + + +if __name__ == "__main__": + success = test_deepmath_scoring() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/verl/utils/reward_score/geo3k.py b/verl/utils/reward_score/geo3k.py index 699445cd7..8a8508758 100644 --- a/verl/utils/reward_score/geo3k.py +++ b/verl/utils/reward_score/geo3k.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import re from mathruler.grader import extract_boxed_content, grade_answer @@ -23,10 +22,15 @@ def format_reward(predict_str: str) -> float: return 1.0 if match_result else 0.0 -def acc_reward(predict_str: str, ground_truth: str) -> float: - answer = extract_boxed_content(predict_str) +def acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> float: + if use_boxed: + answer = extract_boxed_content(predict_str) + else: + answer = predict_str return 1.0 if grade_answer(answer, ground_truth) else 0.0 -def compute_score(predict_str: str, ground_truth: str) -> float: - return 0.9 * acc_reward(predict_str, ground_truth) + 0.1 * format_reward(predict_str) +def compute_score(predict_str: str, ground_truth: str, use_boxed: bool = True, format_score: float = 0.1) -> float: + return (1.0 - format_score) * acc_reward(predict_str, ground_truth, use_boxed) + format_score * format_reward( + predict_str + ) diff --git a/verl/utils/reward_score/gsm8k.py b/verl/utils/reward_score/gsm8k.py index f5d4c1585..98a8c24dc 100644 --- a/verl/utils/reward_score/gsm8k.py +++ b/verl/utils/reward_score/gsm8k.py @@ -14,18 +14,26 @@ import re +_SOLUTION_CLIP_CHARS = 300 + def extract_solution(solution_str, method="strict"): assert method in ["strict", "flexible"] + # Optimization: Regular expression matching on very long strings can be slow. + # For math problems, the final answer is usually at the end. + # We only match on the last 300 characters, which is a safe approximation for 300 tokens. + if len(solution_str) > _SOLUTION_CLIP_CHARS: + solution_str = solution_str[-_SOLUTION_CLIP_CHARS:] + if method == "strict": # this also tests the formatting of the model - solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) - if solution is None: + solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str) + if len(solutions) == 0: final_answer = None else: - final_answer = solution.group(0) - final_answer = final_answer.split("#### ")[1].replace(",", "").replace("$", "") + # take the last solution + final_answer = solutions[-1].replace(",", "").replace("$", "") elif method == "flexible": answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) final_answer = None @@ -44,7 +52,8 @@ def extract_solution(solution_str, method="strict"): def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): """The scoring function for GSM8k. - Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024. + Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual + Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024. Args: solution_str: the solution text diff --git a/verl/utils/reward_score/ifbench/check_ifbench_data.py b/verl/utils/reward_score/ifbench/check_ifbench_data.py new file mode 100644 index 000000000..8719f28c4 --- /dev/null +++ b/verl/utils/reward_score/ifbench/check_ifbench_data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +""" +Check and fix IFBench data format. +""" + +import pandas as pd +import json + +def check_ifbench_data(file_path): + """Check the format of IFBench data file.""" + print(f"Checking file: {file_path}") + + try: + df = pd.read_parquet(file_path) + print(f"Total rows: {len(df)}") + print(f"Columns: {list(df.columns)}") + + # Check data structure + print("\nSample data:") + print(df.head(2)) + + # Check for None values in extra_info + if 'extra_info' in df.columns: + none_count = df['extra_info'].isna().sum() + print(f"\nNone values in extra_info: {none_count}") + + # Show sample of None values + if none_count > 0: + print("\nSample rows with None extra_info:") + print(df[df['extra_info'].isna()].head(2)) + + # Check prompt structure + if 'prompt' in df.columns: + print(f"\nPrompt column type: {df['prompt'].dtype}") + print("Sample prompt:") + print(df['prompt'].iloc[0]) + + # Check reward_model structure + if 'reward_model' in df.columns: + print(f"\nReward model column type: {df['reward_model'].dtype}") + print("Sample reward_model:") + print(df['reward_model'].iloc[0]) + + except Exception as e: + print(f"Error reading file: {e}") + +def fix_ifbench_data(input_file, output_file): + """Fix IFBench data by ensuring extra_info is not None.""" + print(f"Fixing data: {input_file} -> {output_file}") + + df = pd.read_parquet(input_file) + + # Fix None values in extra_info + if 'extra_info' in df.columns: + # Replace None with dict containing default fields + def fix_extra_info(info): + if info is None: + return {"split": "train", "instruction_id_list": []} + elif isinstance(info, str): + try: + parsed = json.loads(info) + if not parsed: # Empty dict + return {"split": "train", "instruction_id_list": []} + return parsed + except: + return {"split": "train", "instruction_id_list": []} + elif isinstance(info, dict): + if not info: # Empty dict + return {"split": "train", "instruction_id_list": []} + return info + else: + return {"split": "train", "instruction_id_list": []} + + df['extra_info'] = df['extra_info'].apply(fix_extra_info) + + # Ensure data_source exists + if 'data_source' not in df.columns: + df['data_source'] = 'ood__ifbench' + + # Save fixed data + df.to_parquet(output_file, index=False) + print(f"Fixed data saved to: {output_file}") + + # Verify the fix + check_ifbench_data(output_file) + +def main(): + # Check original data + original_file = "/mnt/sharefs/users/jianshu.she/ood__ifbench_95.1k.parquet" + print("=== Checking original data ===") + check_ifbench_data(original_file) + + # Check split data if exists + train_file = "/mnt/sharefs/users/jianshu.she/ifbench_split/ifbench_train.parquet" + test_file = "/mnt/sharefs/users/jianshu.she/ifbench_split/ifbench_test.parquet" + + import os + if os.path.exists(train_file): + print("\n=== Checking train data ===") + check_ifbench_data(train_file) + + if os.path.exists(test_file): + print("\n=== Checking test data ===") + check_ifbench_data(test_file) + + # Fix the data + print("\n=== Fixing data ===") + fix_ifbench_data(original_file, "/mnt/sharefs/users/jianshu.she/ood__ifbench_95.1k_fixed.parquet") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/verl/utils/reward_score/ifbench/instructions.py b/verl/utils/reward_score/ifbench/instructions.py index 145fb5be0..2fa118d60 100644 --- a/verl/utils/reward_score/ifbench/instructions.py +++ b/verl/utils/reward_score/ifbench/instructions.py @@ -15,14 +15,13 @@ """Library of instructions.""" - import collections import json import random import re import string from typing import Dict, Optional, Sequence, Union - +import unicodedata import langdetect import logging @@ -343,7 +342,7 @@ def check_following(self, value): requirement. """ bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) - bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) + bullet_lists_2 = re.findall(r"^\s*-\s+.+$", value, flags=re.MULTILINE) num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) return num_bullet_lists == self._num_bullets @@ -376,11 +375,11 @@ def check_following(self, value): True if the actual response contains one of the options in the constrained responses; otherwise False. """ - value = value.strip() - for constrained_response in self._constrained_responses: - if constrained_response in value: - return True - return False + v = value.strip() + return any( + v == opt + for opt in self._constrained_responses + ) class ConstrainedStartChecker(Instruction): @@ -422,7 +421,7 @@ def check_following(self, value): True if the response starts with the given phrase or keyword that is contained in `instruction_args`; otherwise, False. """ - response_pattern = r"^\s*" + self._starter + r".*$" + response_pattern = r"^\s*" + re.escape(self._starter) + r".*$" response_with_constrained_start = re.search(response_pattern, value, flags=re.MULTILINE) return True if response_with_constrained_start else False @@ -477,7 +476,8 @@ def check_following(self, value): if highlight.strip("*").strip(): num_highlights += 1 for highlight in double_highlights: - if highlight.removeprefix("**").removesuffix("**").strip(): + core = highlight[2:-2].strip() # avoid Python 3.9+ only API + if core: num_highlights += 1 return num_highlights >= self._num_highlights @@ -537,10 +537,10 @@ def check_following(self, value): True if the number of sections in the response is greater than or equal to the minimum number of sections; otherwise, False. """ - section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" + section_splitter_patten = r"\s?" + re.escape(self._section_spliter) + r"\s?\d+\s?" sections = re.split(section_splitter_patten, value) num_sections = len(sections) - 1 - return num_sections >= self._num_sections + return num_sections == self._num_sections class ParagraphChecker(Instruction): @@ -647,7 +647,7 @@ def check_following(self, value): elif self._postscript_marker == "P.S.": postscript_pattern = r"\s*p\.\s?s\..*$" else: - postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" + postscript_pattern = r"\s*" + re.escape(self._postscript_marker.lower()) + r".*$" postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) return True if postscript else False @@ -707,12 +707,12 @@ def check_following(self, value): return response_without_changes == reference_without_changes def is_change(self, response): - """Check if there is change in the response in the form of *change me*.""" - return re.search(r"\*.*\*", response) + """Check if there is change in the response in the form of *change me* (non-greedy).""" + return re.search(r"\*.*?\*", response) def strip_changes(self, response): - """Strips off the changes.""" - return re.sub(r"\*.*\*", "", response) + """Strips off the changes (non-greedy).""" + return re.sub(r"\*.*?\*", "", response) class KeywordChecker(Instruction): @@ -750,7 +750,8 @@ def get_instruction_args_keys(self): def check_following(self, value): """Check if the response contain the expected keywords.""" for keyword in self._keywords: - if not re.search(keyword, value, flags=re.IGNORECASE): + pattern = r"\b" + re.escape(keyword) + r"\b" + if not re.search(pattern, value, flags=re.IGNORECASE): return False return True @@ -810,7 +811,8 @@ def get_instruction_args_keys(self): def check_following(self, value): """Checks if the response contain the keyword with required frequency.""" - actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) + pattern = r"\b" + re.escape(self._keyword) + r"\b" + actual_occurrences = len(re.findall(pattern, value, flags=re.IGNORECASE)) if self._comparison_relation == _COMPARISON_RELATION[0]: return actual_occurrences < self._frequency @@ -890,17 +892,13 @@ def get_instruction_args_keys(self): return [] def check_following(self, value): - value = ( - value.strip() - .removeprefix("```json") - .removeprefix("```Json") - .removeprefix("```JSON") - .removeprefix("```") - .removesuffix("```") - .strip() - ) + # More robust fence stripping: supports ```json / ```JSON with optional newline/spaces + v = value.strip() + v = re.sub(r"^```(?:\s*[jJ][sS][oO][nN])?\s*\n?", "", v) + v = re.sub(r"\n?```\s*$", "", v) + v = v.strip() try: - json.loads(value) + json.loads(v) except ValueError: return False return True @@ -929,7 +927,7 @@ def build_description(self, num_paragraphs=None, nth_paragraph=None, first_word= self._nth_paragraph = nth_paragraph if self._nth_paragraph is None or self._nth_paragraph <= 0 or self._nth_paragraph > self._num_paragraphs: - self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) + self._nth_paragraph = random.randint(1, self._num_paragraphs) self._first_word = first_word if self._first_word is None: @@ -1089,7 +1087,7 @@ def get_instruction_args_keys(self): def check_following(self, value): """Check if the response does not contain the expected keywords.""" for word in self._forbidden_words: - if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): + if re.search(r"\b" + re.escape(word) + r"\b", value, flags=re.IGNORECASE): return False return True @@ -1530,7 +1528,10 @@ def build_description(self, phrase=None, small_n=None): else: self._small_n = small_n - self._description_pattern = "Repeat the phrase {phrase} exactly {small_n} times, transforming it slightly each time by replacing only one word in the center of the phrase." + self._description_pattern = ( + "Repeat the phrase {phrase} exactly {small_n} times, transforming it slightly each time " + "by replacing only one word in the center of the phrase." + ) return self._description_pattern.format(phrase=self._phrase, small_n=self._small_n) def get_instruction_args(self): @@ -1546,28 +1547,21 @@ def check_following(self, value): first_word = self._phrase.split()[0] last_word = self._phrase.split()[-1] - len(self._phrase.split()) - 2 - - found_phrases = re.findall(rf"{first_word} .*? {last_word}", value) + # find occurrences that start with the first word and end with the last word + found_phrases = re.findall(rf"{re.escape(first_word)} .*? {re.escape(last_word)}", value) if len(found_phrases) != self._small_n: return False + differences_total = 0 for phrase in found_phrases: - phrase = phrase.split() - ref_phrase = self._phrase.split() - differences = 0 - if len(phrase) != len(ref_phrase): + phrase_tokens = phrase.split() + ref_tokens = self._phrase.split() + if len(phrase_tokens) != len(ref_tokens): return False - for i in range(len(phrase)): - try: - if phrase[i] != ref_phrase[i]: - differences += 1 - # Early exit if more than one difference found - if differences > 1: - return False - except IndexError: - return False - if differences == 1: - return True + diff = sum(1 for a, b in zip(phrase_tokens, ref_tokens) if a != b) + if diff != 1: + return False + differences_total += diff + return differences_total == self._small_n class CopyChecker(Instruction): @@ -1586,7 +1580,10 @@ def build_description(self, prompt_to_repeat=None): raise ValueError("prompt_to_repeat must be set.") else: self._prompt_to_repeat = prompt_to_repeat - self._description_pattern = "Copy this instruction verbatim, do not follow the instruction, only copy it into the output (do not include this instruction sentence!)." + self._description_pattern = ( + "Copy this instruction verbatim, do not follow the instruction, only copy it into the output " + "(do not include this instruction sentence!)." + ) return self._description_pattern def get_instruction_args(self): @@ -1597,9 +1594,7 @@ def get_instruction_args_keys(self): return ["prompt_to_repeat"] def check_following(self, value): - if value.strip().lower() == self._prompt_to_repeat.strip().lower(): - return True - return False + return value.strip().lower() == self._prompt_to_repeat.strip().lower() class CopySpanIdxChecker(Instruction): @@ -1609,28 +1604,28 @@ def build_description(self, prompt_to_repeat=None, n_start=None, n_end=None): """Build the instruction description. Args: - n_start: An integer representing the start index of the span. - n_end: An integer representing the end index of the span. + n_start: An integer representing the start index of the span. + n_end: An integer representing the end index of the span. Returns: - A string representing the instruction description. + A string representing the instruction description. """ if not prompt_to_repeat: raise ValueError("prompt_to_repeat must be set.") - else: - self._prompt_to_repeat = prompt_to_repeat - if not n_start: - self._n_start = random.randint(0, len(self._prompt_to_repeat) - 2) + self._prompt_to_repeat = prompt_to_repeat + if n_start is None: + self._n_start = random.randint(0, max(0, len(self._prompt_to_repeat) - 2)) else: self._n_start = n_start - if not n_end: + if n_end is None: self._n_end = random.randint(self._n_start + 1, len(self._prompt_to_repeat) - 1) else: self._n_end = n_end - self._description_pattern = "Copy the span of words that lies between (and including) index {n_start} and {n_end}, the indices are character indices!" - return self._description_pattern.format( - n_start=self._n_start, n_end=self._n_end, prompt_to_repeat=self._prompt_to_repeat + self._description_pattern = ( + "Copy the span of words that lies between (and including) index {n_start} and {n_end}, " + "the indices are character indices!" ) + return self._description_pattern.format(n_start=self._n_start, n_end=self._n_end) def get_instruction_args(self): """Returns the keyward args of `build_description`.""" @@ -1641,10 +1636,9 @@ def get_instruction_args_keys(self): return ["n_start", "n_end", "prompt_to_repeat"] def check_following(self, value): - """Checks if the response contains the expected number of phrases with the correct modifications.""" - if value.strip().lower() == self._prompt_to_repeat[self._n_start : self._n_end].strip().lower(): - return True - return False + """Checks if the response equals the requested character span (inclusive).""" + expected = self._prompt_to_repeat[self._n_start : self._n_end + 1] + return value.strip().lower() == expected.strip().lower() class SentenceHyphenChecker(Instruction): @@ -1664,24 +1658,15 @@ def get_instruction_args_keys(self): def check_following(self, value): """Checks if all sentences are connected using hyphens, with no spaces between them.""" - # 检查是否包含连字符 if "-" not in value: return False - - # 按连字符分割 - words = value.split("-") - - # 检查每个片段是否有空格(不应该有) - for word in words: - if word.strip() != word: - return False - if " " in word: - return False - - # 检查是否至少有两个片段 - if len(words) < 2: + parts = value.split("-") + if len(parts) < 2: return False - + for p in parts: + # no leading/trailing spaces and no internal spaces + if p.strip() != p or " " in p: + return False return True @@ -1704,11 +1689,11 @@ def check_following(self, value): """Checks if no two adjacent words start with consecutive letters of the alphabet.""" words = value.split() for i in range(len(words) - 1): - first_letter = words[i][0].lower() - second_letter = words[i + 1][0].lower() - if len(first_letter) != 1 or len(second_letter) != 1: + a = words[i][0].lower() + b = words[i + 1][0].lower() + if len(a) != 1 or len(b) != 1: return False - if ord(second_letter) - ord(first_letter) == 1: + if ord(b) - ord(a) == 1: return False return True @@ -1731,10 +1716,7 @@ def get_instruction_args_keys(self): def check_following(self, value): """Checks if every word in the response is enclosed within square brackets.""" words = value.split() - for word in words: - if not (word.startswith("[") and word.endswith("]")): - return False - return True + return all(w.startswith("[") and w.endswith("]") for w in words) class KeywordFrequencyOnceChecker(Instruction): @@ -1745,16 +1727,6 @@ def build_description(self, *, keyword=None): Args: keyword: A string representing a keyword that is expected in the response. - frequency: An integer specifying the number of times `keyword` is expected - to appear in the response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of occurrences < frequency; - if 'at least', the actual number of occurrences >= frequency. - - Returns: - A string representing the instruction description. """ if not keyword: self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] @@ -1762,10 +1734,8 @@ def build_description(self, *, keyword=None): self._keyword = keyword.strip() self._frequency = 1 - self._description_pattern = "Include keyword {keyword} in your response." - - return self._description_pattern.format(keyword=self._keyword, frequency=self._frequency) + return self._description_pattern.format(keyword=self._keyword) def get_instruction_args(self): """Returns the keyward args of `build_description`.""" @@ -1776,13 +1746,9 @@ def get_instruction_args_keys(self): return ["keyword"] def check_following(self, value): - """Checks if the response contain the keyword with required frequency.""" - actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) - - if actual_occurrences == 1: - return True - else: - return False + """Checks if the response contain the keyword exactly once.""" + pattern = r"\b" + re.escape(self._keyword) + r"\b" + return len(re.findall(pattern, value, flags=re.IGNORECASE)) == 1 class KeywordFrequencyCheckerDifferent(Instruction): @@ -1793,40 +1759,25 @@ def build_description(self, *, keyword=None, frequency=None, relation=None): Args: keyword: A string representing a keyword that is expected in the response. - frequency: An integer specifying the number of times `keyword` is expected - to appear in the response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of occurrences < frequency; - if 'at least', the actual number of occurrences >= frequency. - - Returns: - A string representing the instruction description. + frequency: An integer specifying the number of times `keyword` is expected to appear. + relation: One of ('less than', 'at least'). """ if not keyword: self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] else: self._keyword = keyword.strip() - self._frequency = frequency - if self._frequency is None or self._frequency < 0: - self._frequency = random.randint(1, _KEYWORD_FREQUENCY) + self._frequency = frequency if (frequency is not None and frequency >= 0) else random.randint(1, _KEYWORD_FREQUENCY) if relation is None: self._comparison_relation = random.choice(_COMPARISON_RELATION) elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) + raise ValueError(f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given.") else: self._comparison_relation = relation self._description_pattern = "In your response, the word {keyword} should appear {frequency} times." - - return self._description_pattern.format( - keyword=self._keyword, relation=self._comparison_relation, frequency=self._frequency - ) + return self._description_pattern.format(keyword=self._keyword, frequency=self._frequency) def get_instruction_args(self): """Returns the keyward args of `build_description`.""" @@ -1837,27 +1788,23 @@ def get_instruction_args_keys(self): return ["keyword", "frequency", "relation"] def check_following(self, value): - """Checks if the response contain the keyword with required frequency.""" - actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) - + """Checks if the response contain the keyword with required frequency under the relation.""" + pattern = r"\b" + re.escape(self._keyword) + r"\b" + actual = len(re.findall(pattern, value, flags=re.IGNORECASE)) if self._comparison_relation == _COMPARISON_RELATION[0]: - return actual_occurrences < self._frequency - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return actual_occurrences >= self._frequency # pytype: disable=bad-return-type + return actual < self._frequency + else: + return actual >= self._frequency class ExcludeWordHarderChecker(Instruction): - """Checks that specified words are not used in response.""" + """Checks that a specified word is not used in response.""" def build_description(self, keyword=None, instruction=None): """Build the instruction description. Args: - forbidden_words: A sequences of strings respresenting words that are not - allowed in the response. - - Returns: - A string representing the instruction description. + keyword: word to exclude. If None, pick a random token from `instruction`. """ if not keyword: self._keyword = random.choice(instruction.split()) @@ -1865,7 +1812,6 @@ def build_description(self, keyword=None, instruction=None): self._keyword = keyword.strip() self._description_pattern = "Do not include keyword {keyword} in the response." - return self._description_pattern.format(keyword=self._keyword) def get_instruction_args(self): @@ -1877,59 +1823,33 @@ def get_instruction_args_keys(self): return ["keyword"] def check_following(self, value): - """Check if the response does not contain the expected keywords.""" - if " " + self._keyword + " " in value: - return False - return True + pattern = r"\b" + re.escape(self._keyword) + r"\b" + return re.search(pattern, value, flags=re.IGNORECASE) is None class ParagraphBasicChecker(Instruction): """Checks the paragraphs.""" def build_description(self): - """Build the instruction description. - - Args: - num_paragraphs: An integer specifying the number of paragraphs. - - Returns: - A string representing the instruction description. - """ - self._description_pattern = ( - "There should be 2 paragraphs. " + "Paragraphs are separated with the markdown divider: ***" - ) - + """Build the instruction description.""" + self._description_pattern = "There should be 2 paragraphs. Paragraphs are separated with the markdown divider: ***" return self._description_pattern def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks the response contains required number of paragraphs. - - Args: - value: A string representing the response. The response may contain - paragraphs that are separated by the markdown divider: `***`. - - Returns: - True if the actual number of paragraphs is the same as required; - otherwise, False. - """ paragraphs = re.split(r"\s?\*\*\*\s?", value) num_paragraphs = len(paragraphs) - - for index, paragraph in enumerate(paragraphs): - if not paragraph.strip(): - if index == 0 or index == len(paragraphs) - 1: + for i, p in enumerate(paragraphs): + if not p.strip(): + if i in (0, len(paragraphs) - 1): num_paragraphs -= 1 else: return False - return num_paragraphs == 2 @@ -1937,47 +1857,27 @@ class ParagraphBasicChecker2(Instruction): """Checks the paragraphs.""" def build_description(self): - """Build the instruction description. - - Args: - num_paragraphs: An integer specifying the number of paragraphs. - - Returns: - A string representing the instruction description. - """ - self._description_pattern = "There should be 2 paragraphs. Paragraphs and only paragraphs are separated with each other by two line breaks. " - - return self._description_pattern.format() + """Build the instruction description.""" + self._description_pattern = ( + "There should be 2 paragraphs. Paragraphs and only paragraphs are separated with each other by two line breaks. " + ) + return self._description_pattern def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks the response contains required number of paragraphs. - - Args: - value: A string representing the response. The response may contain - paragraphs that are separated by the markdown divider: `***`. - - Returns: - True if the actual number of paragraphs is the same as required; - otherwise, False. - """ paragraphs = re.split(r"\n\n", value) num_paragraphs = len(paragraphs) - - for index, paragraph in enumerate(paragraphs): - if not paragraph.strip(): - if index == 0 or index == len(paragraphs) - 1: + for i, p in enumerate(paragraphs): + if not p.strip(): + if i in (0, len(paragraphs) - 1): num_paragraphs -= 1 else: return False - return num_paragraphs == 2 @@ -1985,115 +1885,64 @@ class FirstWordSentChecker(Instruction): """The first word of each sentence should be the word {first_word}.""" def build_description(self, first_word=None): - """Build the instruction description. - - Args: - first_word: A string representing the first word of each sentence. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not first_word: self._first_word = instructions_util.generate_keywords(num_keywords=1)[0] else: if not isinstance(first_word, str): - self._first_word == first_word[0].strip() + self._first_word = first_word[0].strip() else: self._first_word = first_word.strip() - self._description_pattern = "The first word of each sentence should be the word {first_word}." - return self._description_pattern.format(first_word=self._first_word) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"first_word": self._first_word} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["first_word"] def check_following(self, value): - """Checks if the first word of each sentence is the expected word. - - Args: - value: A string representing the response. - - Returns: - True if the first word of each sentence is the expected word; - otherwise, False. - """ sentences = instructions_util.split_into_sentences(value) - - # Check if the first word of each sentence matches the expected word - for sentence in sentences: - if not sentence.strip(): + for s in sentences: + if not s.strip(): return False - first_word = sentence.split()[0].strip() - if first_word.lower() != self._first_word.lower(): + fw = s.split()[0].strip() + if fw.lower() != self._first_word.lower(): return False return True class FirstWordAnswerChecker(Instruction): - """The first word of each sentence should be the word {first_word}.""" + """The first word of your response should be the word {first_word}.""" def build_description(self, first_word=None): - """Build the instruction description. - - Args: - first_word: A string representing the first word of each sentence. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not first_word: self._first_word = instructions_util.generate_keywords(num_keywords=1)[0] else: self._first_word = first_word.strip() - self._description_pattern = "The first word of your response should be the word {first_word}." - return self._description_pattern.format(first_word=self._first_word) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"first_word": self._first_word} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["first_word"] def check_following(self, value): - """Checks if the first word of each sentence is the expected word. - - Args: - value: A string representing the response. - - Returns: - True if the first word of each sentence is the expected word; - otherwise, False. - """ - if not value.strip() or len(value.split()) == 0: + if not value.strip(): return False - first_word = value.split()[0].strip() - if first_word.lower() != self._first_word.lower(): - return False - return True + fw = value.split()[0].strip() + return fw.lower() == self._first_word.lower() class LastWordSentChecker(Instruction): """The last word of each sentence should be the word {last_word}.""" def build_description(self, last_word=None): - """Build the instruction description. - - Args: - first_word: A string representing the last word of each sentence. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not last_word: self._last_word = instructions_util.generate_keywords(num_keywords=1)[0] else: @@ -2105,35 +1954,20 @@ def build_description(self, last_word=None): self._description_pattern = ( "The last word of each sentence, before punctuation, should be the word {last_word}." ) - return self._description_pattern.format(last_word=self._last_word) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"last_word": self._last_word} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["last_word"] def check_following(self, value): - """Checks if the first word of each sentence is the expected word. - - Args: - value: A string representing the response. - - Returns: - True if the first word of each sentence is the expected word; - otherwise, False. - """ sentences = instructions_util.split_into_sentences(value) - - # Check if the first word of each sentence matches the expected word - for sentence in sentences: - if not sentence.strip(): + for s in sentences: + if not s.strip(): return False - last_word = sentence.split()[-1].strip() - # remove any punctuation from last_word + last_word = s.split()[-1].strip() last_word = re.sub(r"[^\w\s]", "", last_word) if last_word.lower() != self._last_word.lower(): return False @@ -2144,51 +1978,36 @@ class LastWordAnswerChecker(Instruction): """The last word of your response should be the word {last_word}.""" def build_description(self, last_word=None): - """Build the instruction description. - - Args: - first_word: A string representing the last word of each sentence. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not last_word: self._last_word = instructions_util.generate_keywords(num_keywords=1)[0] else: self._last_word = last_word.strip() self._description_pattern = "The last word of your response should be the word {last_word}." - return self._description_pattern.format(last_word=self._last_word) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"last_word": self._last_word} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["last_word"] def check_following(self, value): - """Checks if the first word of each sentence is the expected word. - - Args: - value: A string representing the response. + """Checks if the last word of the entire response equals the expected word. - Returns: - True if the first word of each sentence is the expected word; - otherwise, False. + More robust: normalize the string using NFKC; split using non-alphanumeric delimiters. """ - last_word = value.split()[-1].strip() - # remove any punctuation from last_word - last_word = re.sub(r"[^\w\s]", "", last_word) - if last_word.lower() != self._last_word.lower(): + norm = unicodedata.normalize("NFKC", value).strip() + tokens = re.split(r"[^A-Za-z0-9]+", norm) + tokens = [t for t in tokens if t] + if not tokens: return False - return True + return tokens[-1].lower() == self._last_word.lower() class BiGramWrappingChecker(Instruction): - "Wrap every word bigram in double angular brackets, such as <> <> <> <>." + """Wrap every word bigram in double angular brackets, such as <> <> <> <>.""" def build_description(self): """Build the instruction description.""" @@ -2201,119 +2020,98 @@ def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): """Checks if every word bigram is enclosed within double angular brackets.""" words = value.split() - for i in range(0, len(words) - 1, 2): - if i + 1 < len(words): - if not (words[i].startswith("<<") and words[i + 1].endswith(">>")): - return False + if len(words) % 2 != 0: + return False + # Expect pairs like ["<>", "<>", ...] + for i in range(0, len(words), 2): + left, right = words[i], words[i + 1] + if not (left.startswith("<<") and right.endswith(">>")): + return False + # Ensure only the boundary tokens carry the brackets + if left.endswith(">>") or right.startswith("<<"): + return False return True class CopyingSimpleChecker(Instruction): - "Repeat the request without change (do not say anything before repeating the request; the request you need to repeat does not include this sentence) and do not answer the actual request!" + """Repeat the request without change and do not answer the actual request.""" def build_description(self, prompt_to_repeat=None): - """Build the instruction description. - - Args: - prompt_to_repeat: The prompt that is meant to be repeated. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not prompt_to_repeat: raise ValueError("prompt_to_repeat must be set.") - else: - self._prompt_to_repeat = prompt_to_repeat - self._description_pattern = "Repeat the request without change (do not say anything before repeating the request; the request you need to repeat does not include this sentence) and do not answer the actual request!" + self._prompt_to_repeat = prompt_to_repeat + self._description_pattern = ( + "Repeat the request without change (do not say anything before repeating the request; " + "the request you need to repeat does not include this sentence) and do not answer the actual request!" + ) return self._description_pattern def get_instruction_args(self): return {"prompt_to_repeat": self._prompt_to_repeat} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["prompt_to_repeat"] def check_following(self, value): - if value.strip().lower() == self._prompt_to_repeat.strip().lower(): - return True - return False + return value.strip().lower() == self._prompt_to_repeat.strip().lower() class CopyingMultipleChecker(Instruction): - "Repeat the request without change {N} times, separated by 6 asterisk symbols (do not say anything before repeating the request; the request you need to repeat does not include this sentence) and do not answer the actual request!" + """Repeat the request without change N times, separated by 6 asterisks.""" def build_description(self, prompt_to_repeat=None, N=None): - """Build the instruction description. - - Args: - prompt_to_repeat: The prompt that is meant to be repeated. - N: An integer representing the number of times to repeat the phrase. - - Returns: - A string representing the instruction description. - """ + """Build the instruction description.""" if not prompt_to_repeat: raise ValueError("prompt_to_repeat must be set.") - else: - self._prompt_to_repeat = prompt_to_repeat - if not N: - self._N = random.randint(2, 3) - else: - self._N = N - self._description_pattern = "Repeat the request without change {N} times, separated by 6 asterisk symbols (do not say anything before repeating the request; the request you need to repeat does not include this sentence) and do not answer the actual request!" + self._prompt_to_repeat = prompt_to_repeat + self._N = N if N else random.randint(2, 3) + self._description_pattern = ( + "Repeat the request without change {N} times, separated by 6 asterisk symbols " + "(do not say anything before repeating the request; the request you need to repeat does not include this sentence) " + "and do not answer the actual request!" + ) return self._description_pattern.format(N=self._N) def get_instruction_args(self): return {"prompt_to_repeat": self._prompt_to_repeat, "N": self._N} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["prompt_to_repeat", "N"] def check_following(self, value): - prompts = value.split("******") - if len(prompts) != self._N: + parts = value.split("******") + if len(parts) != self._N: return False - for prompt in prompts: - if prompt.strip().lower() != self._prompt_to_repeat.strip().lower(): - return False - return True + return all(p.strip().lower() == self._prompt_to_repeat.strip().lower() for p in parts) class PunctuationDotChecker(Instruction): - "In your entire response, refrain from the use of . (i.e. dots) as punctuation and in general." + """In your entire response, refrain from the use of . (dots).""" def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "In your entire response, refrain from the use of . (i.e. dots) as punctuation and in general." - ) + self._description_pattern = "In your entire response, refrain from the use of . (i.e. dots) as punctuation and in general." return self._description_pattern def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks that the response does not contain dots.""" - return not re.search(r"\.", value) + return "." not in value class PunctuationExclamationChecker(Instruction): - "In your entire response, refrain from the use of ! (i.e. exclamation marks) as punctuation and in general." + """In your entire response, refrain from the use of ! (exclamation marks).""" def build_description(self): - """Build the instruction description.""" self._description_pattern = "In your entire response, refrain from the use of ! (i.e. exclamation marks) as punctuation and in general." return self._description_pattern @@ -2321,30 +2119,17 @@ def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks that the response does not contain exclamation marks.""" - return not re.search(r"\!", value) + return "!" not in value class LowercaseCountingChecker(Instruction): - "In your response, all lowercase words should appear at most {N} times." + """In your response, all lowercase words should appear at most N times.""" def build_description(self, N=None): - """Build the instruction description. - - Args: - N: An integer representing the maximum number of lowercase words allowed. - - Returns: - A string representing the instruction description. - """ - if not N: - self._N = random.randint(2, 3) - else: - self._N = N + self._N = N if N is not None else random.randint(2, 3) self._description_pattern = "In your response, all lowercase words should appear at most {N} times." return self._description_pattern.format(N=self._N) @@ -2352,38 +2137,19 @@ def get_instruction_args(self): return {"N": self._N} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["N"] def check_following(self, value): - """Checks that the response does not contain lowercase words more than N times.""" lowercase_words = re.findall(r"\b[a-z]+\b", value) - if len(lowercase_words) <= self._N: - return True - else: - return False + return len(lowercase_words) <= self._N class LetterCountingChecker(Instruction): - "Answer with {relation} {N} letters." + """Answer with {relation} {N} letters.""" def build_description(self, N=None, relation=None): - """Build the instruction description. - - Args: - N: An integer representing the maximum number of letters allowed. - - Returns: - A string representing the instruction description. - """ - if not N: - self._N = random.randint(2, 3) - else: - self._N = N - if not relation: - self._relation = random.choice(_COMPARISON_RELATION) - else: - self._relation = relation + self._N = N if N is not None else random.randint(2, 3) + self._relation = relation if relation else random.choice(_COMPARISON_RELATION) self._description_pattern = "Answer with {relation} {N} letters." return self._description_pattern.format(N=self._N, relation=self._relation) @@ -2391,95 +2157,57 @@ def get_instruction_args(self): return {"N": self._N, "relation": self._relation} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["N", "relation"] def check_following(self, value): - """Checks that the response does not contain lowercase words more than N times.""" - letters = re.findall(r"[a-zA-Z]", value) + letters = re.findall(r"[A-Za-z]", value) if self._relation == "at least": - if len(letters) >= self._N: - return True - else: - return False - elif self._relation == "less than": - if len(letters) < self._N: - return True - else: - return False + return len(letters) >= self._N + else: + return len(letters) < self._N class CountingCompositionChecker(Instruction): - "Write 3 paragraphs, delimited by the markdown divider: * * *, with exactly {n_sent} sentences each, with exactly {n_words} words in each sentence." + """Write 3 paragraphs, delimited by ***, with exactly n_sent sentences each, and n_words words per sentence.""" def build_description(self, n_sent=None, n_words=None): - """Build the instruction description. - - Args: - n_sent: An integer representing the number of sentences in each paragraph. - n_words: An integer representing the number of words in each sentence. - - Returns: - A string representing the instruction description. - """ - if not n_sent: - self._n_sent = random.randint(2, 3) - else: - self._n_sent = n_sent - if not n_words: - self._n_words = random.randint(2, 3) - else: - self._n_words = n_words - self._description_pattern = "Write 3 paragraphs, delimited by the markdown divider: * * *, with exactly {n_sent} sentences each, with exactly {n_words} words in each sentence." + self._n_sent = n_sent if n_sent is not None else random.randint(2, 3) + self._n_words = n_words if n_words is not None else random.randint(2, 3) + self._description_pattern = ( + "Write 3 paragraphs, delimited by the markdown divider: * * *, with exactly {n_sent} sentences each, " + "with exactly {n_words} words in each sentence." + ) return self._description_pattern.format(n_sent=self._n_sent, n_words=self._n_words) def get_instruction_args(self): return {"n_sent": self._n_sent, "n_words": self._n_words} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["n_sent", "n_words"] def check_following(self, value): - """Checks that the response contains the expected number of paragraphs, sentences, and words. - - Args: - value: A string representing the response. - - Returns: - True if the response meets the requirements; otherwise, False. - """ paragraphs = re.split(r"\s?\*\*\*\s?", value) num_paragraphs = len(paragraphs) - - for index, paragraph in enumerate(paragraphs): - if not paragraph.strip(): - if index == 0 or index == len(paragraphs) - 1: + for idx, para in enumerate(paragraphs): + if not para.strip(): + if idx in (0, len(paragraphs) - 1): num_paragraphs -= 1 else: return False - - sentences = instructions_util.split_into_sentences(paragraph) - num_sentences = len(sentences) - - if num_sentences != self._n_sent: + sentences = instructions_util.split_into_sentences(para) + if len(sentences) != self._n_sent: return False - - for sentence in sentences: - words = instructions_util.nltk.word_tokenize(sentence) - num_words = len(words) - - if num_words != self._n_words: + for s in sentences: + words = instructions_util.nltk.word_tokenize(s) + if len(words) != self._n_words: return False - return num_paragraphs == 3 class CountUniqueChecker(Instruction): - "Only use unique words in your response, no word should be repeated!" + """Only use unique words in your response, no word should be repeated!""" def build_description(self): - """Build the instruction description.""" self._description_pattern = "Only use unique words in your response, no word should be repeated!" return self._description_pattern @@ -2487,76 +2215,46 @@ def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks that the response contains unique words.""" words = instructions_util.nltk.word_tokenize(value) - unique_words = set(words) - return len(words) == len(unique_words) + return len(words) == len(set(words)) class CountIncrementWordChecker(Instruction): - "Include keyword {keyword1} once in your response, keyword {keyword2} twice in your response." + """Include keyword {keyword1} once, and {keyword2} twice.""" def build_description(self, keyword1=None, keyword2=None): - """Build the instruction description. - - Args: - keyword1: A string representing a keyword that is expected in the response. - keyword2: A string representing a keyword that is expected in the response. - - Returns: - A string representing the instruction description. - """ if not keyword1: - self._keyword1 = instructions_util.generate_keywords(num_keywords=1) + self._keyword1 = instructions_util.generate_keywords(num_keywords=1)[0] else: self._keyword1 = keyword1.strip() if not keyword2: - self._keyword2 = instructions_util.generate_keywords(num_keywords=1) + self._keyword2 = instructions_util.generate_keywords(num_keywords=1)[0] else: self._keyword2 = keyword2.strip() - self._description_pattern = ( "Include keyword {keyword1} once in your response, keyword {keyword2} twice in your response." ) - return self._description_pattern.format(keyword1=self._keyword1, keyword2=self._keyword2) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"keyword1": self._keyword1, "keyword2": self._keyword2} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["keyword1", "keyword2"] def check_following(self, value): - """Checks if the response contains the expected number of keywords. - - Args: - value: A string representing the response. - - Returns: - True if the response contains the expected number of keywords; - otherwise, False. - """ - actual_occurrences1 = len(re.findall(self._keyword1, value, flags=re.IGNORECASE)) - actual_occurrences2 = len(re.findall(self._keyword2, value, flags=re.IGNORECASE)) - - if actual_occurrences1 == 1 and actual_occurrences2 == 2: - return True - else: - return False + occ1 = len(re.findall(r"\b" + re.escape(self._keyword1) + r"\b", value, flags=re.IGNORECASE)) + occ2 = len(re.findall(r"\b" + re.escape(self._keyword2) + r"\b", value, flags=re.IGNORECASE)) + return occ1 == 1 and occ2 == 2 class PalindromeBasicChecker(Instruction): - "Include a palindrome in your response." + """Include a palindrome in your response.""" def build_description(self): - """Build the instruction description.""" self._description_pattern = "Include a palindrome in your response." return self._description_pattern @@ -2564,123 +2262,72 @@ def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): - """Checks if the response contains a palindrome. - - Args: - value: A string representing the response. - - Returns: - True if the response contains a palindrome; otherwise, False. - """ - palindromes = [word for word in value.split() if word == word[::-1]] - return len(palindromes) > 0 + # simple word-level palindrome check + words = value.split() + return any(w == w[::-1] and len(w) > 1 for w in words) class KeywordSpecificPositionChecker(Instruction): - "Include keyword {keyword1} in the {n}-th sentence, as the {m}-th word of that sentence." + """Include keyword {keyword} in the n-th sentence, as the m-th word of that sentence.""" def build_description(self, keyword=None, n=None, m=None): - """Build the instruction description. - - Args: - keyword: A string representing a keyword that is expected in the response. - n: An integer representing the sentence number. - m: An integer representing the word number. - - Returns: - A string representing the instruction description. - """ if not keyword: self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] else: - if not isinstance(keyword, str): - self._keyword = keyword[0].strip() - else: - self._keyword = keyword.strip() - if not n: - self._n = random.randint(1, 20) - else: - self._n = n - if not m: - self._m = random.randint(1, 30) - else: - self._m = m - - self._description_pattern = ( - "Include keyword {keyword} in the {n}-th sentence, as the {m}-th word of that sentence." - ) - + self._keyword = keyword.strip() + self._n = n if n else random.randint(1, 20) + self._m = m if m else random.randint(1, 30) + self._description_pattern = "Include keyword {keyword} in the {n}-th sentence, as the {m}-th word of that sentence." return self._description_pattern.format(keyword=self._keyword, n=self._n, m=self._m) def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" return {"keyword": self._keyword, "n": self._n, "m": self._m} def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return ["keyword", "n", "m"] def check_following(self, value): - """Checks if the response contains the expected number of keywords. - - Args: - value: A string representing the response. - - Returns: - True if the response contains the expected number of keywords; - otherwise, False. - """ sentences = instructions_util.split_into_sentences(value) if len(sentences) < self._n: return False - words = instructions_util.nltk.word_tokenize(sentences[self._n - 1]) + sent = sentences[self._n - 1] + words = [w for w in instructions_util.nltk.word_tokenize(sent) if w.strip()] if len(words) < self._m: return False - if words[self._m - 1] == self._keyword: - return True - else: - return False + target = re.sub(r"\W+$", "", words[self._m - 1]).lower() + return target == self._keyword.lower() class StartEndChecker(Instruction): - "Start and end your response with the same word (do not write anything after the last word, not even punctuation)." + """Start and end your response with the same word (no trailing punctuation after the last word).""" def build_description(self): - """Build the instruction description.""" - self._description_pattern = "Start and end your response with the same word (do not write anything after the last word, not even punctuation)." + self._description_pattern = ( + "Start and end your response with the same word (do not write anything after the last word, not even punctuation)." + ) return self._description_pattern def get_instruction_args(self): return None def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" return [] def check_following(self, value): """Checks if the response starts and ends with the same word. - Args: - value: A string representing the response. - - Returns: - True if the response starts and ends with the same word; - otherwise, False. + If hyphen-connected style is used (e.g., a-b-c), treat hyphen-separated tokens as words. + Otherwise, fall back to tokenizer. """ - # 对于连字符格式,使用连字符分割 - if "-" in value: + if "-" in value and " " not in value.strip(): words = value.split("-") else: - # 对于普通格式,使用NLTK分词 words = instructions_util.nltk.word_tokenize(value) - + words = [w for w in words if w.strip()] if len(words) < 2: return False - if words[0].lower() == words[-1].lower(): - return True - else: - return False \ No newline at end of file + return words[0].lower() == words[-1].lower() + diff --git a/verl/utils/reward_score/ifbench/split_fixed_data.py b/verl/utils/reward_score/ifbench/split_fixed_data.py new file mode 100644 index 000000000..31ab4e295 --- /dev/null +++ b/verl/utils/reward_score/ifbench/split_fixed_data.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +Split the fixed IFBench data into train and test sets. +""" + +import pandas as pd +import os + +def split_fixed_data(): + """Split the fixed IFBench data into 90% train and 10% test.""" + + # Input file (fixed data) + input_file = "/mnt/sharefs/users/jianshu.she/ood__ifbench_95.1k_fixed.parquet" + + # Output directory + output_dir = "/mnt/sharefs/users/jianshu.she/ifbench_split" + os.makedirs(output_dir, exist_ok=True) + + # Output files + train_file = os.path.join(output_dir, "ifbench_train_fixed.parquet") + test_file = os.path.join(output_dir, "ifbench_test_fixed.parquet") + + print(f"Reading fixed data from: {input_file}") + df = pd.read_parquet(input_file) + print(f"Total rows: {len(df)}") + + # Split data: 90% train, 10% test + train_size = int(0.9 * len(df)) + test_size = len(df) - train_size + + print(f"Splitting into {train_size} train samples and {test_size} test samples") + + # Split the dataframe + train_df = df.iloc[:train_size] + test_df = df.iloc[train_size:] + + # Save split data + train_df.to_parquet(train_file, index=False) + test_df.to_parquet(test_file, index=False) + + print(f"Train data saved to: {train_file}") + print(f"Test data saved to: {test_file}") + + # Verify the split + print(f"\nVerification:") + print(f"Train file rows: {len(pd.read_parquet(train_file))}") + print(f"Test file rows: {len(pd.read_parquet(test_file))}") + + # Check extra_info in split files + train_data = pd.read_parquet(train_file) + test_data = pd.read_parquet(test_file) + + print(f"\nTrain file - None values in extra_info: {train_data['extra_info'].isna().sum()}") + print(f"Test file - None values in extra_info: {test_data['extra_info'].isna().sum()}") + +if __name__ == "__main__": + split_fixed_data() \ No newline at end of file diff --git a/verl/utils/reward_score/ifbench/test_ifbench.py b/verl/utils/reward_score/ifbench/test_ifbench.py new file mode 100644 index 000000000..c57586bc9 --- /dev/null +++ b/verl/utils/reward_score/ifbench/test_ifbench.py @@ -0,0 +1,55 @@ +# test_ifbench.py +import sys +import os + +# Add project root directory to Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../../')) +sys.path.insert(0, project_root) + +from .instructions_registry import INSTRUCTION_DICT +from . import compute_score + +# Test data +example = { + "data_source": "ood__ifbench", + "prompt": [ + {"role": "user", "content": "If \"A woman in pink is sitting and enjoying the object she is holding.\" does that mean that \"Woman sitting.\"?\nOptions:\n- yes\n- it is not possible to tell\n- no Stream of thoughts: There should be 4 paragraphs. Paragraphs are separated with the markdown divider: *** In your response, the letter s should appear at least 20 times. Include keyword flight in your response."} + ], + "reward_model": { + "style": "rule", + "ground_truth": "[{'instruction_id': ['length_constraints:number_paragraphs', 'letters:letter_counting2', 'keywords:word_once'], 'kwargs': [{'num_paragraphs': 4}, {'letter': 's', 'let_frequency': 20, 'let_relation': 'at least'}, {'keyword': 'flight'}]}]" + } +} + +# Solution that meets the requirements +solution = """yes + +The statement clearly indicates that a woman is sitting, which directly answers the question about whether a woman is sitting. + +*** + +The woman in pink is specifically mentioned as sitting and enjoying an object, which provides sufficient information to conclude that a woman is indeed sitting in this scenario. + +*** + +The keyword flight appears in this context as we discuss the woman's position and activities, demonstrating how various elements can be connected in meaningful ways. + +*** + +The letter s appears multiple times throughout this response, satisfying the requirement for frequent usage while maintaining coherent and logical content structure.""" + +# Compute score +result = compute_score( + solution_str=solution, + ground_truth=example["reward_model"]["ground_truth"], + extra_info=None +) + +print(f"acc: {result['acc']}") +print(f"reward: {result['score']}") + +# Verify requirements +print(f"\nVerification of requirements:") +print(f"Number of paragraphs: {solution.count('***') + 1}") +print(f"Number of occurrences of letter 's': {solution.count('s')}") +print(f"Contains 'flight': {'flight' in solution.lower()}") \ No newline at end of file diff --git a/verl/utils/reward_score/math.py b/verl/utils/reward_score/math.py index 95970492c..3fff7bc04 100644 --- a/verl/utils/reward_score/math.py +++ b/verl/utils/reward_score/math.py @@ -210,7 +210,8 @@ def strip_string(string): # remove spaces string = string.replace(" ", "") - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). + # Also does a/b --> \\frac{a}{b} string = fix_fracs(string) # manually change 0.5 --> \frac{1}{2} diff --git a/verl/utils/reward_score/math_bak.py b/verl/utils/reward_score/math_bak.py new file mode 100644 index 000000000..3fff7bc04 --- /dev/null +++ b/verl/utils/reward_score/math_bak.py @@ -0,0 +1,224 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + + +def compute_score(solution_str, ground_truth) -> float: + retval = 0.0 + try: + string_in_last_boxed = last_boxed_only_string(solution_str) + if string_in_last_boxed is not None: + answer = remove_boxed(string_in_last_boxed) + if is_equiv(answer, ground_truth): + retval = 1.0 + except Exception as e: + print(e) + + return retval + + +# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print("WARNING: Both None") + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def remove_boxed(s): + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1] + + return retval + + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: # noqa: E722 + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: # noqa: E722 + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). + # Also does a/b --> \\frac{a}{b} + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = fix_a_slash_b(string) + + return string diff --git a/verl/utils/reward_score/math_batch.py b/verl/utils/reward_score/math_batch.py index 48f73c281..ed080860a 100644 --- a/verl/utils/reward_score/math_batch.py +++ b/verl/utils/reward_score/math_batch.py @@ -20,4 +20,7 @@ def compute_score_batched(data_sources, solution_strs, ground_truths, extra_info This is a demonstration of how the batched reward function should look like. Typically, you want to use batched reward to speed up the process with parallelization """ - return [compute_score(solution_str, ground_truth) for solution_str, ground_truth in zip(solution_strs, ground_truths)] + return [ + compute_score(solution_str, ground_truth) + for solution_str, ground_truth in zip(solution_strs, ground_truths, strict=True) + ] diff --git a/verl/utils/reward_score/math_dapo.py b/verl/utils/reward_score/math_dapo.py index 33a699e56..940500fd5 100644 --- a/verl/utils/reward_score/math_dapo.py +++ b/verl/utils/reward_score/math_dapo.py @@ -62,7 +62,6 @@ def remove_boxed(s: str) -> str: return s[len(left) : -1] - # Constants for normalization SUBSTITUTIONS = [ ("an ", ""), @@ -163,7 +162,9 @@ def normalize_final_answer(final_answer: str) -> str: return final_answer.strip() -def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]: +def is_correct_minerva( + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +) -> tuple[bool, str]: """Check if the solution is correct according to Minerva criteria. Args: @@ -189,7 +190,9 @@ def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False return (pred == gt), pred -def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: +def is_correct_strict_box( + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +) -> tuple[int, Optional[str]]: """Check if the prediction is correct using strict boxed answer criteria. Args: @@ -214,7 +217,9 @@ def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[ return 1 if (extracted_pred == gt) else -1, extracted_pred -def verify(solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None) -> bool: +def verify( + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +) -> bool: """Verify if the solution is correct. Args: diff --git a/verl/utils/reward_score/nemotron_stem.py b/verl/utils/reward_score/nemotron_stem.py new file mode 100644 index 000000000..f1f982f4d --- /dev/null +++ b/verl/utils/reward_score/nemotron_stem.py @@ -0,0 +1,90 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re + + +def extract_solution(solution_str, method='strict'): + """ + Extract the final answer choice from an LLM's response to a multiple-choice nemotron_stem question. + + Args: + solution_str (str): The full text response from the LLM + method (str): 'strict' for exact format matching, 'flexible' for more lenient matching + + Returns: + str: The extracted answer choice (A, B, C, or D) or None if not found + """ + assert method in ['strict', 'flexible'] + + if method == 'strict': + # First try to find answer in boxed format + boxed_match = re.search(r"\\boxed\{([A-D])\}", solution_str) + if boxed_match: + return boxed_match.group(1) + + # Then try standard "Answer:" format + answer_match = re.search(r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?", solution_str) + if answer_match: + return answer_match.group(1) + + # Try to find single letter answers at the end + end_match = re.search(r"\b([A-D])\b(?!.*\b[A-D]\b)", solution_str) + if end_match: + return end_match.group(1) + + return None + + elif method == 'flexible': + # Look for answers in parentheses + answer = re.findall(r"\(([A-D])\)", solution_str) + if answer: + return answer[-1] # Return the last found answer + + # Look for boxed answers + boxed_answer = re.findall(r"\\boxed\{([A-D])\}", solution_str) + if boxed_answer: + return boxed_answer[-1] + + # Look for any A, B, C, D pattern + general_answer = re.findall(r"\b([A-D])\b", solution_str) + if general_answer: + return general_answer[-1] + + return None + + +def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1., extra_info=None): + """The scoring function for nemotron_stem dataset. + + Args: + solution_str: the solution text + ground_truth: the ground truth answer (A, B, C, or D) + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format when answer is extractable but wrong + score: the score for the correct answer + extra_info: additional information (not used in this implementation) + + Returns: + dict: A dictionary containing 'score' and 'acc' keys + """ + answer = extract_solution(solution_str=solution_str, method=method) + if answer is None: + return {'score': 0, 'acc': 0} + else: + if answer == ground_truth: + return {'score': score, 'acc': 1.} + else: + return {'score': format_score, 'acc': 0.} \ No newline at end of file diff --git a/verl/utils/reward_score/nemotron_stem_test.py b/verl/utils/reward_score/nemotron_stem_test.py new file mode 100644 index 000000000..96c815f5d --- /dev/null +++ b/verl/utils/reward_score/nemotron_stem_test.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import pandas as pd +from verl.utils.reward_score import default_compute_score +from verl.utils.reward_score.nemotron_stem import compute_score, extract_solution + + +def test_extract_solution(): + """Test the extract_solution function with various response formats.""" + print("Testing extract_solution function...") + + test_cases = [ + ("The answer is A", "A"), + ("Answer: B", "B"), + ("\\boxed{C}", "C"), + ("After careful analysis, the answer is D.", "D"), + ("I think (C) is correct", "C"), + ("No clear answer", None), + ("The final answer is \\boxed{A}.", "A"), + ("Answer: C", "C"), + ] + + for response, expected in test_cases: + result = extract_solution(response, method='strict') + print(f"Input: '{response}' -> Expected: {expected}, Got: {result}") + assert result == expected, f"Failed for '{response}': expected {expected}, got {result}" + + print("extract_solution tests passed!\n") + + +def test_compute_score(): + """Test the compute_score function.""" + print("Testing compute_score function...") + + # Test correct answer + result = compute_score("Answer: A", "A") + print(f"Correct answer test: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + # Test incorrect answer + result = compute_score("Answer: B", "A") + print(f"Incorrect answer test: {result}") + assert result == {'score': 0.0, 'acc': 0.0} + + # Test no extractable answer + result = compute_score("I don't know", "A") + print(f"No answer test: {result}") + assert result == {'score': 0, 'acc': 0} + + print("compute_score tests passed!\n") + + +def test_default_compute_score(): + """Test the default_compute_score function with nemotron_stem data source.""" + print("Testing default_compute_score with nemotron_stem...") + + # Test with stem_nemotron data source + result = default_compute_score("stem_nemotron", "Answer: C", "C") + print(f"stem_nemotron correct: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + result = default_compute_score("stem_nemotron", "Answer: A", "C") + print(f"stem_nemotron incorrect: {result}") + assert result == {'score': 0.0, 'acc': 0.0} + + # Test with nemotron_stem data source + result = default_compute_score("nemotron_stem", "\\boxed{B}", "B") + print(f"nemotron_stem correct: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + print("default_compute_score tests passed!\n") + + +def test_real_data(): + """Test with real nemotron_stem data.""" + print("Testing with real nemotron_stem data...") + + try: + # Load a sample of the test data + df = pd.read_parquet('/mnt/sharefs/users/jianshu.she/nemotron_stem/test_data_final.parquet') + sample = df.head(5) + + print(f"Testing with {len(sample)} samples from real data...") + + for idx, row in sample.iterrows(): + data_source = row['data_source'] + response = row['response'] + ground_truth = row['reward_model']['ground_truth'] + + # Test with our implementation + result = default_compute_score(data_source, response, ground_truth) + print(f"Sample {idx}: response='{response}', ground_truth='{ground_truth}', score={result}") + + print("Real data test completed!\n") + + except Exception as e: + print(f"Could not test with real data: {e}") + + +if __name__ == "__main__": + test_extract_solution() + test_compute_score() + test_default_compute_score() + test_real_data() + print("All tests passed!") \ No newline at end of file diff --git a/verl/utils/reward_score/prime_code/README.md b/verl/utils/reward_score/prime_code/README.md new file mode 100644 index 000000000..ddd445714 --- /dev/null +++ b/verl/utils/reward_score/prime_code/README.md @@ -0,0 +1,16 @@ +## LiveCodeBench + +### Introduction +[LiveCodeBench](https://github.com/LiveCodeBench/LiveCodeBench) provides holistic and contamination-free evaluation of coding capabilities of LLMs. Particularly, LiveCodeBench continuously collects new problems over time from contests across three competition platforms -- LeetCode, AtCoder, and CodeForces. + +### How to reproduce +Our evaluation is grounded on the version found in LiveCodeBench. +> **Installation** +```bash +# Make sure the CUDA version > 12.0. +pip install -r requirements.txt +pip install flash-attn --no-build-isolation +``` + +### Acknowleage +Thank you to the [LiveCodeBench](https://livecodebench.github.io/leaderboard.html) team for their contributions to the open-source community. \ No newline at end of file diff --git a/verl/utils/reward_score/prime_code/testing_util.py b/verl/utils/reward_score/prime_code/testing_util.py index da57e44de..2f2232518 100644 --- a/verl/utils/reward_score/prime_code/testing_util.py +++ b/verl/utils/reward_score/prime_code/testing_util.py @@ -48,6 +48,7 @@ class CODE_TYPE(Enum): call_based = 0 standard_input = 1 + # used to capture stdout as a list # from https://stackoverflow.com/a/16571630/6416660 # alternative use redirect_stdout() from contextlib @@ -222,7 +223,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): in_outs["outputs"][index] = json.loads(in_outs["outputs"][index]) truncate_line_size = 300 // (raw_inputs.count("\n") + 1) - raw_inputs = "\n".join([truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")]) + raw_inputs = "\n".join( + [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")] + ) raw_outputs = truncatefn(raw_outputs, 200) else: raw_inputs = truncatefn(raw_inputs) @@ -245,7 +248,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): pass if debug: - print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}") + print( + f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. " + f"type = {which_type}" + ) if which_type == CODE_TYPE.call_based: # Call-based signal.alarm(timeout) faulthandler.enable() @@ -295,7 +301,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): faulthandler.disable() signal.alarm(0) if debug: - print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, " + f"{type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) elif which_type == CODE_TYPE.standard_input: # Standard input faulthandler.enable() passed = False @@ -330,9 +339,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if debug: nl = "\n" if not isinstance(inputs, list): - print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, " + f"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, " + f"{output == [in_outs['outputs'][index]]}" + ) else: - print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, " + f"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) continue if passed and debug: @@ -367,7 +383,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(in_outs["outputs"][index], list): for tmp_index, i in enumerate(in_outs["outputs"][index]): in_outs["outputs"][index][tmp_index] = i.split("\n") - in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x] + in_outs["outputs"][index][tmp_index] = [ + x.strip() for x in in_outs["outputs"][index][tmp_index] if x + ] else: in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) @@ -393,9 +411,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if debug: nl = "\n" if not isinstance(inputs, list): - print(f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}") + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, " + f"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, " + f"{output == [in_outs['outputs'][index]]} {tmp_result=}" + ) else: - print(f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}") + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, " + f"{type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + ) if debug: print(f"{tmp_result=} @a") @@ -413,13 +438,23 @@ def run_test(in_outs, test=None, debug=False, timeout=15): print(f"{tmp_result=} @b") try: - all_ints = all(combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output, in_outs["outputs"][index])) + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index], strict=True) + ) if not all_ints: if debug: - print([combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output, in_outs["outputs"][index])]) + print( + [ + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index], strict=True) + ] + ) output_float = [float(e) for e in output] gt_float = [float(e) for e in in_outs["outputs"][index]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) except Exception: pass @@ -428,11 +463,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): try: if isinstance(output[0], list): - all_ints = all(combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output[0], in_outs["outputs"][index])) + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output[0], in_outs["outputs"][index], strict=True) + ) if not all_ints: output_float = [float(e) for e in output[0]] gt_float = [float(e) for e in in_outs["outputs"][index][0]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) except Exception: pass @@ -497,9 +537,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if debug: nl = "\n" if not isinstance(inputs, list): - print(f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, " + f"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, " + f"{output == [in_outs['outputs'][index]]}" + ) else: - print(f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, " + f"{type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) print(f"results = {results}") @@ -630,3 +677,7 @@ def reliability_guard(maximum_memory_bytes=None): sys.modules["resource"] = None sys.modules["psutil"] = None sys.modules["tkinter"] = None + + # Disable some built-in functions that can be destructive + for mod in ["subprocess", "ctypes"]: + sys.modules[mod] = None diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index 270d7ab8e..8d9d273e3 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -21,7 +21,6 @@ import contextlib import math -import os import re import sympy @@ -41,6 +40,7 @@ BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] TUPLE_CHARS = "()[]" + def _sympy_parse(expr: str): """Parses an expression with sympy.""" py_expr = expr.replace("^", "**") @@ -231,7 +231,12 @@ def split_tuple(expr: str): expr = _strip_properly_formatted_commas(expr) if len(expr) == 0: return [] - if len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]): + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): elems = [elem.strip() for elem in expr[1:-1].split(",")] else: elems = [expr] @@ -270,16 +275,21 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: ground_truth_elems = split_tuple(ground_truth_normalized) given_elems = split_tuple(given_normalized) - if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) or len(ground_truth_elems) != len(given_elems): + if ( + len(ground_truth_elems) > 1 + and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) + or len(ground_truth_elems) != len(given_elems) + ): is_correct = False else: - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): + for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): if _is_frac(ground_truth_elem) and _is_frac(given_elem): # if fractions aren't reduced, then shouldn't be marked as correct # so, we don't want to allow sympy.simplify in this case is_correct = ground_truth_elem == given_elem elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): - # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) + # if the ground truth answer is an integer, we require the given answer to be a strict match + # (no sympy.simplify) is_correct = False else: try: diff --git a/verl/utils/reward_score/prime_math/grader.py b/verl/utils/reward_score/prime_math/grader.py index 0f6ccb072..72bb749f2 100644 --- a/verl/utils/reward_score/prime_math/grader.py +++ b/verl/utils/reward_score/prime_math/grader.py @@ -96,7 +96,6 @@ import math import re from math import isclose -from typing import Union # sympy related from sympy import N, simplify @@ -125,7 +124,9 @@ def normalize(answer, pi) -> str: return answer[1:] # checking if answer is % or \\% and removing % - if isinstance(answer, str) and (bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))): + if isinstance(answer, str) and ( + bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + ): return answer.replace("\\%", "").replace("%", "") # handle base @@ -171,8 +172,8 @@ def handle_pi(string, pi): def math_equal( - prediction: Union[bool, float, str], - reference: Union[float, str], + prediction: bool | float | str, + reference: float | str, include_percentage: bool = True, tolerance: float = 1e-4, timeout: float = 10.0, @@ -224,7 +225,9 @@ def math_equal( prediction = format_intervals(prediction) pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): + if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( + prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + ): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") for s in ["{", "}", "(", ")"]: @@ -234,10 +237,22 @@ def math_equal( return True ## [a, b] vs. [c, d], return a==c and b==d - if prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and prediction[0] == reference[0] and prediction[-1] == reference[-1]: + if ( + prediction + and reference + and prediction[0] in "([" + and prediction[-1] in ")]" + and prediction[0] == reference[0] + and prediction[-1] == reference[-1] + ): pred_parts = prediction[1:-1].split(",") ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) + ] + ): return True if "," in prediction and "," in reference: @@ -245,13 +260,25 @@ def math_equal( ref_parts = [item.strip() for item in reference.split(",")] if len(pred_parts) == len(ref_parts): - return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))])) + return bool( + all( + [ + math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + for i in range(len(pred_parts)) + ] + ) + ) # if we have point == tuple of values if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=False) + ] + ): return True # if reference is a matrix @@ -259,7 +286,12 @@ def math_equal( try: pred_matrix = parse_expr(prediction) ref_matrix_items = reference.split()[1:-1:2] - if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False) + ] + ): return True except Exception: pass @@ -268,10 +300,20 @@ def math_equal( try: pred_matrix = eval(prediction) # ref_matrix_items = reference.split()[1:-1:2] - ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip("\\end{pmatrix}").rstrip("\\end{pmatrix}") # noqa: B005 + ref_matrix_items = ( + reference.lstrip("\\begin{pmatrix}") # noqa: B005 + .lstrip("\begin{pmatrix}") + .rstrip("\\end{pmatrix}") + .rstrip("\\end{pmatrix}") + ) # noqa: B005 ref_matrix_items = ref_matrix_items.split("\\") ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] - if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False) + ] + ): return True except Exception: pass @@ -300,7 +342,7 @@ def _parse(s): if simplify(a - b) == 0: return True except TimeoutError: - print(f"Simplification timed out for {a} - {b}") + print(f"Simplification timed out for {a} - {b}") pass except Exception: pass @@ -316,6 +358,7 @@ def _parse(s): pass return False + def format_intervals(prediction): patterns = { "Interval(": r"^Interval\((.*)\)$", diff --git a/verl/utils/reward_score/prime_math/math_normalize.py b/verl/utils/reward_score/prime_math/math_normalize.py index 2036ad142..2ff2cc0a8 100644 --- a/verl/utils/reward_score/prime_math/math_normalize.py +++ b/verl/utils/reward_score/prime_math/math_normalize.py @@ -178,7 +178,8 @@ def _strip_string(string): # remove spaces string = string.replace(" ", "") - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). + # Also does a/b --> \\frac{a}{b} string = _fix_fracs(string) # manually change 0.5 --> \frac{1}{2} diff --git a/verl/utils/reward_score/reasoning_gym/__init__.py b/verl/utils/reward_score/reasoning_gym/__init__.py new file mode 100644 index 000000000..c463aa878 --- /dev/null +++ b/verl/utils/reward_score/reasoning_gym/__init__.py @@ -0,0 +1,219 @@ +import reasoning_gym +from reasoning_gym.utils import extract_answer +import json +import re + +def compute_score(solution_str, ground_truth, extra_info=None, item=None): + """ + Compute the reward score for reasoning gym tasks. + + Args: + solution_str (str): The model's response/solution string + ground_truth (str or dict): The ground truth answer or entry dict + extra_info (str or dict, optional): Should contain 'task' (as JSON string or dict) + item (dict, optional): The full data item, for fallback + + Returns: + dict: {"score": float, "acc": float} + """ + task = None + entry = None + + # 1. Parse extra_info + extra_info_dict = {} + metadata = None + + if extra_info: + if isinstance(extra_info, str): + try: + extra_info_dict = json.loads(extra_info) + except Exception: + extra_info_dict = {} + else: + extra_info_dict = extra_info + + # Get task first + task = extra_info_dict.get("task") + entry = extra_info_dict.get("entry") + + # Handle metadata field if present + if "metadata" in extra_info_dict: + if isinstance(extra_info_dict["metadata"], str): + try: + metadata = json.loads(extra_info_dict["metadata"]) + except Exception: + metadata = {} + elif isinstance(extra_info_dict["metadata"], dict): + metadata = extra_info_dict["metadata"] + + # 2. Try to get from item (fallback - this is rarely used in actual training) + if not task and item and isinstance(item, dict): + task = item.get("ability") + + # 3. Try to get from ground_truth + if not task and isinstance(ground_truth, dict): + task = ground_truth.get("task") + entry = ground_truth + + if not task: + raise ValueError("task must be provided in extra_info, item, or ground_truth dict.") + + # 4. Get scoring function + scorer = reasoning_gym.get_score_answer_fn(task) + + # 5. Get entry + if entry is None: + entry = {"answer": ground_truth} + + # Build metadata field, prioritizing extra_info metadata + if isinstance(entry, dict): + if "metadata" not in entry or not isinstance(entry["metadata"], dict): + entry["metadata"] = {} + if metadata is not None: + entry["metadata"].update(metadata) + if task is not None: + entry["metadata"]["task"] = task + entry["metadata"]["solution_str"] = solution_str + entry["metadata"]["ground_truth"] = ground_truth + if extra_info is not None: + entry["metadata"]["extra_info"] = extra_info + if item is not None: + entry["metadata"]["item"] = item + + # 6. Extract clean answer from solution_str + # Step 1: First extract the answer part from our custom format (remove blocks, etc.) + answer_part = extract_answer_from_solution(solution_str) + + # Step 2: Pass the extracted answer part to official reasoning gym function for standardization + official_answer = extract_answer(answer_part) + + # Step 3: Use official answer if available, otherwise use our extracted answer + if official_answer is not None: + clean_answer = official_answer + else: + # Fallback: if official function can't extract, use our extracted answer + clean_answer = answer_part + + # 7. Scoring with task-specific fixes + debug_log_path = "reasoning_gym_debug.log" + try: + with open(debug_log_path, "a", encoding="utf-8") as f: + f.write("[DEBUG] solution_str: {}\n".format(solution_str)) + f.write("[DEBUG] clean_answer: {}\n".format(clean_answer)) + f.write("[DEBUG] ground_truth: {}\n".format(ground_truth)) + f.write("[DEBUG] task: {}\n".format(task)) + f.write("[DEBUG] metadata: {}\n".format(json.dumps(entry.get("metadata", {}), ensure_ascii=False, indent=2))) + + # Get raw score from reasoning_gym using clean answer + raw_score = scorer(answer=clean_answer, entry=entry) + + # Apply task-specific corrections for known issues + corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score) + + f.write("[DEBUG] raw_score: {}\n".format(raw_score)) + f.write("[DEBUG] corrected_score: {}\n".format(corrected_score)) + + return {"score": float(corrected_score), "acc": float(corrected_score)} + except Exception as e: + with open(debug_log_path, "a", encoding="utf-8") as f: + f.write(f"Error in reasoning gym scoring: {e}\n") + return {"score": 0.0, "acc": 0.0} + + +def apply_task_specific_corrections(task, solution_str, ground_truth, raw_score): + """ + Apply corrections for known issues in specific reasoning_gym tasks. + + Args: + task (str): The task name + solution_str (str): The model's solution + ground_truth (str): The ground truth answer + raw_score (float): The raw score from reasoning_gym + + Returns: + float: The corrected score + """ + + # Fix for puzzle24: Convert partial credit (0.01) to 0.0 for wrong answers + if task == "puzzle24": + if raw_score == 0.01: + # Only give 0.01 if the solution actually attempts the format but is wrong + # Otherwise give 0.0 for completely invalid answers + if is_valid_puzzle24_format(solution_str): + return 0.01 # Keep partial credit for valid format but wrong calculation + else: + return 0.0 # No credit for invalid format + return raw_score + + # Fix for game_of_life_halting: Implement proper scoring since reasoning_gym seems broken + elif task == "game_of_life_halting": + # The reasoning_gym library appears to have a bug for this task + # Implement simple exact string matching as fallback + if solution_str.strip().lower() == ground_truth.strip().lower(): + return 1.0 + else: + return 0.0 + + # For all other tasks, return the raw score + return raw_score + + +def extract_answer_from_solution(solution_str): + """ + Extract the final answer from a solution string that may contain and tags. + + Args: + solution_str (str): The full solution string from the model + + Returns: + str: The extracted answer, or the original string if no answer tags found + """ + # Try to extract from tags first + # Use a more restrictive pattern that doesn't match across multiple tags + answer_pattern = r'\s*([^<]*?)\s*' + matches = re.findall(answer_pattern, solution_str, re.IGNORECASE) + + if matches: + # Return the last answer if multiple found + return matches[-1].strip() + + # If no tags, try to extract everything after the last tag + if '' in solution_str: + parts = solution_str.split('
') + if len(parts) > 1: + # Get everything after the last + answer = parts[-1].strip() + # Remove any remaining HTML-like tags + answer = re.sub(r'<[^>]+>', '', answer).strip() + if answer: + return answer + + # If no structured format found, return the original string + # This handles cases where the model generates direct answers without tags + return solution_str.strip() + + +def is_valid_puzzle24_format(solution_str): + """ + Check if a solution string follows a valid puzzle24 format. + Valid formats include mathematical expressions with +, -, *, /, (, ) and numbers. + """ + import re + + # Remove whitespace + solution = solution_str.strip() + + # Check if it contains only valid characters for mathematical expressions + valid_chars = re.match(r'^[0-9+\-*/.() ]+$', solution) + if not valid_chars: + return False + + # Check if it contains at least some mathematical operators + has_operators = any(op in solution for op in ['+', '-', '*', '/']) + + # Check if it contains numbers + has_numbers = re.search(r'\d', solution) + + return bool(has_operators and has_numbers) + + diff --git a/verl/utils/reward_score/reasoning_gym/test.py b/verl/utils/reward_score/reasoning_gym/test.py new file mode 100644 index 000000000..b43267707 --- /dev/null +++ b/verl/utils/reward_score/reasoning_gym/test.py @@ -0,0 +1,7 @@ +import reasoning_gym +data = reasoning_gym.create_dataset('kakurasu', size=10, seed=42) +for i, x in enumerate(data): + print(f'{i}: q="{x['question']}", a="{x['answer']}"') + print('metadata:', x['metadata']) + # use the dataset's `score_answer` method for algorithmic verification + assert data.score_answer(answer=x['answer'], entry=x) == 1.0 \ No newline at end of file diff --git a/verl/utils/reward_score/sandbox_fusion/__init__.py b/verl/utils/reward_score/sandbox_fusion/__init__.py index 194a41833..cd18498dd 100644 --- a/verl/utils/reward_score/sandbox_fusion/__init__.py +++ b/verl/utils/reward_score/sandbox_fusion/__init__.py @@ -25,7 +25,9 @@ logger = logging.getLogger(__name__) -def compute_score(sandbox_fusion_url, concurrent_semaphore, completion, test_cases, continuous=False, timeout=10): +def compute_score( + sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, completion, test_cases, continuous=False, timeout=10 +): """ Computes the code score using the remote sandbox API. @@ -74,7 +76,14 @@ def compute_score(sandbox_fusion_url, concurrent_semaphore, completion, test_cas # Note: The return value of check_correctness might need adaptation here # Assume check_correctness returns (results_list, metadata_list) # results_list contains True, False, or error codes (-1, -2, -3, etc.) - res_list, metadata_list = check_correctness(sandbox_fusion_url=sandbox_fusion_url, in_outs=test_cases, generation=solution, timeout=timeout, concurrent_semaphore=concurrent_semaphore) + res_list, metadata_list = check_correctness( + sandbox_fusion_url=sandbox_fusion_url, + in_outs=test_cases, + generation=solution, + timeout=timeout, + concurrent_semaphore=concurrent_semaphore, + memory_limit_mb=memory_limit_mb, + ) # Calculate score if not res_list: # If there are no results (e.g., invalid input) diff --git a/verl/utils/reward_score/sandbox_fusion/utils.py b/verl/utils/reward_score/sandbox_fusion/utils.py index d054b7d27..6d395ce5c 100644 --- a/verl/utils/reward_score/sandbox_fusion/utils.py +++ b/verl/utils/reward_score/sandbox_fusion/utils.py @@ -19,7 +19,7 @@ import time import traceback import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import requests @@ -31,10 +31,48 @@ logger = logging.getLogger(__name__) # Define supported languages list (optional, for documentation or validation) -SUPPORTED_LANGUAGES = ["python", "cpp", "nodejs", "go", "go_test", "java", "php", "csharp", "bash", "typescript", "sql", "rust", "cuda", "lua", "R", "perl", "D_ut", "ruby", "scala", "julia", "pytest", "junit", "kotlin_script", "jest", "verilog", "python_gpu", "lean", "swift", "racket"] - - -def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_timeout: int, run_timeout: int, language: str = "python") -> Tuple[Optional[Dict[str, Any]], Optional[str]]: # <-- Remove request_id parameter +SUPPORTED_LANGUAGES = [ + "python", + "cpp", + "nodejs", + "go", + "go_test", + "java", + "php", + "csharp", + "bash", + "typescript", + "sql", + "rust", + "cuda", + "lua", + "R", + "perl", + "D_ut", + "ruby", + "scala", + "julia", + "pytest", + "junit", + "kotlin_script", + "jest", + "verilog", + "python_gpu", + "lean", + "swift", + "racket", +] + + +def call_sandbox_api( + sandbox_fusion_url: str, + code: str, + stdin: Optional[str], + compile_timeout: int, + run_timeout: int, + memory_limit_mb: int, + language: str = "python", +) -> tuple[Optional[dict[str, Any]], Optional[str]]: # <-- Remove request_id parameter """ Calls the remote sandbox API to execute code with retry logic for Gateway Timeout, using increasing delay between retries. Logs internal calls with a unique ID. @@ -66,6 +104,7 @@ def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_tim "run_timeout": run_timeout, "code": code, "stdin": stdin, + "memory_limit_MB": memory_limit_mb, "language": language, # Use the passed language parameter "files": {}, "fetch_files": [], @@ -79,7 +118,9 @@ def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_tim for attempt in range(MAX_RETRIES): try: - logger.info(f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling sandbox API at {sandbox_fusion_url}") # <-- Use internal log_prefix + logger.info( + f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling sandbox API at {sandbox_fusion_url}" + ) # <-- Use internal log_prefix response = requests.post( sandbox_fusion_url, headers=headers, @@ -89,7 +130,10 @@ def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_tim # Check for Gateway Timeout (504) specifically for retrying if response.status_code == 504: - last_error = f"{log_prefix}API Request Error: Gateway Timeout (504) on attempt {attempt + 1}/{MAX_RETRIES}" # <-- Use internal log_prefix + last_error = ( + f"{log_prefix}API Request Error: Gateway Timeout (504) on attempt " + f"{attempt + 1}/{MAX_RETRIES}" + ) # <-- Use internal log_prefix logger.warning(last_error) if attempt < MAX_RETRIES - 1: # Don't sleep after the last attempt # Calculate increasing delay (e.g., 1s, 2s, 4s, ...) or (1s, 2s, 3s, ...) @@ -104,7 +148,9 @@ def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_tim response.raise_for_status() # If successful (status code 2xx) - logger.info(f"{log_prefix}Sandbox API call successful on attempt {attempt + 1}") # <-- Use internal log_prefix + logger.info( + f"{log_prefix}Sandbox API call successful on attempt {attempt + 1}" + ) # <-- Use internal log_prefix return response.json(), None except requests.exceptions.RequestException as e: @@ -125,7 +171,18 @@ def call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_tim return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries" -def _process_single_case(case_index: int, stdin_data: Any, expected_output: Any, sandbox_fusion_url: str, generation: str, timeout: int, language: str, concurrent_semaphore: Optional[threading.Semaphore] = None, fn_name: Optional[str] = None) -> Tuple[int, Dict[str, Any]]: +def _process_single_case( + case_index: int, + stdin_data: Any, + expected_output: Any, + sandbox_fusion_url: str, + generation: str, + timeout: int, + memory_limit_mb: int, + language: str, + concurrent_semaphore: Optional[threading.Semaphore] = None, + fn_name: Optional[str] = None, +) -> tuple[int, dict[str, Any]]: """Helper function to process a single test case.""" api_response = None error_msg = None @@ -186,7 +243,8 @@ def _execute_user_function(): try: _args = [json.loads(line) for line in _raw_input_str.split('\\n')] except json.JSONDecodeError as _je: - sys.stderr.write(f"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\nInput was: {{_raw_input_str[:200]}}\\n") + sys.stderr.write(f"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\nInput was: " + f"{{_raw_input_str[:200]}}\\n") return None, True # result, error_occurred # --- Function Location and Execution --- @@ -201,9 +259,9 @@ def _execute_user_function(): # Attempt to instantiate and get method. # Errors (e.g., Solution not a class, instantiation fails, method missing) # will be caught by the broad except block below. - _solution_instance = _Solution_class() + _solution_instance = _Solution_class() _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME) - + if not _target_callable: sys.stderr.write(f"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\n") return None, True # result, error_occurred @@ -228,19 +286,36 @@ def _execute_user_function(): print(str(_result)) # Optional: To explicitly exit with an error code if the sandbox relies on it # else: - # sys.exit(1) + # sys.exit(1) """ current_generation_code = wrapper_code + stdin = None if stdin_data is None else str(stdin_data) try: if concurrent_semaphore: # logger.debug(f"Case {case_index + 1}: Attempting to acquire semaphore.") with concurrent_semaphore: # logger.debug(f"Case {case_index + 1}: Semaphore acquired. Calling API.") - api_response, error_msg = call_sandbox_api(sandbox_fusion_url=sandbox_fusion_url, code=current_generation_code, stdin=str(stdin_data), compile_timeout=timeout, run_timeout=timeout, language=language) + api_response, error_msg = call_sandbox_api( + sandbox_fusion_url=sandbox_fusion_url, + code=current_generation_code, + stdin=stdin, + compile_timeout=timeout, + run_timeout=timeout, + memory_limit_mb=memory_limit_mb, + language=language, + ) # logger.debug(f"Case {case_index + 1}: Semaphore released.") else: - api_response, error_msg = call_sandbox_api(sandbox_fusion_url=sandbox_fusion_url, code=current_generation_code, stdin=str(stdin_data), compile_timeout=timeout, run_timeout=timeout, language=language) + api_response, error_msg = call_sandbox_api( + sandbox_fusion_url=sandbox_fusion_url, + code=current_generation_code, + stdin=stdin, + compile_timeout=timeout, + run_timeout=timeout, + memory_limit_mb=memory_limit_mb, + language=language, + ) except Exception as e: error_msg = f"API Request Exception during check_correctness for case {case_index + 1}: {e}" logger.error(f"Case {case_index + 1}: {error_msg}") @@ -248,7 +323,7 @@ def _execute_user_function(): metadata = { "case_index": case_index, - "input": str(stdin_data), + "input": stdin, "expected_output": str(expected_output), "api_request_error": error_msg, "api_response": None, @@ -272,7 +347,7 @@ def _execute_user_function(): # Log code and input only on error for brevity generation_to_log = generation[:200] + "..." if len(generation) > 200 else generation logger.error(f"Case {case_index}: code: {generation_to_log}") - logger.error(f"Case {case_index}: input: {str(stdin_data)}") + logger.error(f"Case {case_index}: input: {stdin}") elif api_response: # --- Add debug logging --- logger.debug(f"Case {case_index}: API Response: {api_response}") @@ -308,7 +383,10 @@ def _execute_user_function(): logger.debug(f"Run Result: {run_result}") # --- Check the logic here --- # Compile failed or timed out - is_compile_error = compile_result and (metadata["compile_status"] in ["Error", "TimeLimitExceeded"] or (metadata["compile_status"] == "Finished" and compile_result.get("return_code") != 0)) + is_compile_error = compile_result and ( + metadata["compile_status"] in ["Error", "TimeLimitExceeded"] + or (metadata["compile_status"] == "Finished" and compile_result.get("return_code") != 0) + ) if is_compile_error: # Differentiate between compile_error and compile_timeout based on specific status if metadata["compile_status"] == "TimeLimitExceeded": @@ -319,7 +397,11 @@ def _execute_user_function(): # Run failed or timed out elif run_result: # Modified condition: Check for TimeLimitExceeded OR (Finished with non-zero exit code) OR Error status - is_runtime_error = metadata["run_status"] == "TimeLimitExceeded" or metadata["run_status"] == "Error" or (metadata["run_status"] == "Finished" and run_result.get("return_code") != 0) + is_runtime_error = ( + metadata["run_status"] == "TimeLimitExceeded" + or metadata["run_status"] == "Error" + or (metadata["run_status"] == "Finished" and run_result.get("return_code") != 0) + ) if is_runtime_error: if metadata["run_status"] == "TimeLimitExceeded": metadata["status"] = "timeout" # Runtime timeout @@ -364,7 +446,15 @@ def _execute_user_function(): return result_status, metadata -def check_correctness(sandbox_fusion_url: str, in_outs: Optional[dict], generation: str, timeout: int = DEFAULT_TIMEOUT, language: str = "python", concurrent_semaphore: Optional[threading.Semaphore] = None) -> Tuple[List[Any], List[Dict[str, Any]]]: +def check_correctness( + sandbox_fusion_url: str, + in_outs: Optional[dict], + generation: str, + timeout: int = DEFAULT_TIMEOUT, + memory_limit_mb: int = 1024, + language: str = "python", + concurrent_semaphore: Optional[threading.Semaphore] = None, +) -> tuple[list[Any], list[dict[str, Any]]]: """ Checks the correctness of code generation using the remote sandbox API, processing test cases concurrently. @@ -411,7 +501,22 @@ def check_correctness(sandbox_fusion_url: str, in_outs: Optional[dict], generati # max_workers is limited by sandbox_fusion_max_concurrent from concurrent_semaphore with concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) as executor: # Submit all tasks, passing the concurrent_semaphore to _process_single_case - future_to_index = {executor.submit(_process_single_case, i, stdin_data, expected_outputs[i], sandbox_fusion_url, generation, timeout, language, concurrent_semaphore, fn_name): i for i, stdin_data in enumerate(inputs)} + future_to_index = { + executor.submit( + _process_single_case, + i, + stdin_data, + expected_outputs[i], + sandbox_fusion_url, + generation, + timeout, + memory_limit_mb, + language, + concurrent_semaphore, + fn_name, + ): i + for i, stdin_data in enumerate(inputs) + } # Process results as they complete for future in concurrent.futures.as_completed(future_to_index): @@ -442,7 +547,9 @@ def check_correctness(sandbox_fusion_url: str, in_outs: Optional[dict], generati # Post-processing for compile errors if first_compile_error_index != -1: - logger.warning(f"Compile error detected in case {first_compile_error_index}. Marking subsequent cases as compile errors.") + logger.warning( + f"Compile error detected in case {first_compile_error_index}. Marking subsequent cases as compile errors." + ) for i in range(first_compile_error_index + 1, num_cases): # Only update if not already processed (though it should be None or have a result) if results[i] != -4: # Avoid overwriting if it somehow already got -4 diff --git a/verl/utils/reward_score/synlogic/__init__.py b/verl/utils/reward_score/synlogic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py new file mode 100644 index 000000000..841c70c1c --- /dev/null +++ b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py @@ -0,0 +1,306 @@ +import json +from typing import List, Dict, Tuple +from .verifier import Verifier +from .data import Data +import re + +class ArrowMazeVerifier(Verifier): + """ + 箭头迷宫游戏验证器 + + 验证条件: + 1. 判断answer grid的大小是否和question grid一致 + 2. 判断answer grid中数字格子是否和question grid中数字格子一致 + 3. 判断question grid空格("X")在answer grid中是否被箭头填满 + 4. 判断箭头符号是否合法: + 上(↑)、下(↓)、左(←)、右(→)或对角线方向(↖、↗、↘、↙) + 5. 判断answer grid中非空格("X")和非数字的部分,即预填的箭头,是否和question grid一致 + 6. 迷宫有个隐藏的条件是所有箭头都能被射线箭头串覆盖到 + 7. 每个数字起点发出的射线箭头串总长度等于该数字 + """ + + # 定义合法的箭头符号 + VALID_ARROWS = {"↑", "↓", "←", "→", "↖", "↗", "↘", "↙"} + + # 定义箭头符号和其对应的方向 + ARROWS_DIRECTIONS = { + "↑": (-1, 0), # 上 + "↓": (1, 0), # 下 + "←": (0, -1), # 左 + "→": (0, 1), # 右 + "↖": (-1, -1), # 左上 + "↗": (-1, 1), # 右上 + "↘": (1, 1), # 右下 + "↙": (1, -1) # 左下 + } + + def verify(self, data: Data, test_solution_str: str) -> bool: + + """ + 验证箭头迷宫的答案是否正确 + + @param data: 游戏数据 + @param test_solution_str: 测试答案字符串 (JSON格式的二维数组) + @return: 答案是否正确 + """ + test_answer_str = self.extract_answer(test_solution_str) + if not test_answer_str: + # print("答案为空,验证失败") + return False + + try: + # 解析测试答案 + test_answer = json.loads(test_answer_str) + + # 获取原始迷宫 + question_grid = data.metadata["maze"] + + # 检查答案是否符合要求 + if not self._verify_grid_size(test_answer, question_grid): + # print("答案网格大小与题目不匹配") + return False + + if not self._verify_number_positions(test_answer, question_grid): + # print("答案中数字位置或值与题目不匹配") + return False + + if not self._verify_all_blanks_filled(test_answer, question_grid): + # print("答案中有空格未被填满") + return False + + + if not self._verify_arrow_symbols(test_answer): + # print("答案中包含非法箭头符号") + return False + + + if not self._verify_prefilled_arrows(test_answer, question_grid): + # print("答案中预填箭头与题目不一致") + return False + + if not self._verify_arrow_rays(test_answer): + # print("答案中存在未被射线覆盖的箭头") + return False + + if not self._verify_number_rays(test_answer): + # print("答案中数字的射线箭头串总数不符合要求") + return False + + # 所有验证都通过 + return True + + except Exception as e: + return False + + def _verify_grid_size(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证答案网格大小是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 网格大小是否一致 + """ + if len(test_answer) != len(question_grid): + return False + + for i in range(len(test_answer)): + if len(test_answer[i]) != len(question_grid[i]): + return False + + return True + + def _verify_number_positions(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证答案中数字位置和值是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 数字位置和值是否一致 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + if question_grid[i][j].isdigit(): + if test_answer[i][j] != question_grid[i][j]: + return False + return True + + def _verify_all_blanks_filled(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证所有空格是否都被填满 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 所有空格是否被填满 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + if question_grid[i][j] == "X" and test_answer[i][j] == "X": + return False + return True + + def _verify_arrow_symbols(self, test_answer: List[List[str]]) -> bool: + """ + 验证箭头符号是否合法 + + @param test_answer: 测试答案网格 + @return: 箭头符号是否合法 + """ + for i in range(len(test_answer)): + for j in range(len(test_answer[i])): + cell = test_answer[i][j] + if not cell.isdigit() and cell != "X" and cell not in self.VALID_ARROWS: + return False + return True + + def _verify_prefilled_arrows(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证预填的箭头是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 预填箭头是否一致 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + cell = question_grid[i][j] + if not cell.isdigit() and cell != "X": + if test_answer[i][j] != cell: + return False + return True + + def _verify_arrow_rays(self, test_answer: List[List[str]]) -> bool: + """ + 验证所有箭头是否都能被射线箭头串覆盖到 + + @param test_answer: 测试答案网格 + @return: 所有箭头是否都能被射线覆盖 + """ + n = len(test_answer) + m = len(test_answer[0]) if n > 0 else 0 + + # 创建覆盖标记数组 + covered = [[False for _ in range(m)] for _ in range(n)] + + # 标记数字位置为已覆盖 + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + covered[i][j] = True + + # 从每个数字出发,沿各个方向延伸射线,标记覆盖到的箭头 + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + # 检查所有方向 + for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items(): + ni, nj = i + di, j + dj + # 沿该方向延伸,直到边界或非匹配箭头 + while 0 <= ni < n and 0 <= nj < m and test_answer[ni][nj] == arrow_symbol: + covered[ni][nj] = True + ni += di + nj += dj + + # 检查所有箭头是否都被覆盖 + for i in range(n): + for j in range(m): + if test_answer[i][j] in self.VALID_ARROWS and not covered[i][j]: + return False + + return True + + def _verify_number_rays(self, test_answer: List[List[str]]) -> bool: + """ + 验证每个数字起点发出的射线箭头串总长度是否等于该数字 + + @param test_answer: 测试答案网格 + @return: 每个数字的射线箭头串是否符合要求 + """ + n = len(test_answer) + m = len(test_answer[0]) if n > 0 else 0 + + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + number = int(test_answer[i][j]) + arrow_count = self._count_arrow_rays(test_answer, i, j) + if arrow_count != number: + return False + + return True + + def _count_arrow_rays(self, grid: List[List[str]], i: int, j: int) -> int: + """ + 计算从数字出发的所有射线箭头串中箭头总数 + + @param grid: 网格 + @param i: 数字行索引 + @param j: 数字列索引 + @return: 箭头总数 + """ + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + count = 0 + + # 检查所有方向 + for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items(): + ni, nj = i + di, j + dj + ray_length = 0 + + # 沿该方向计数连续的相同箭头 + while 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == arrow_symbol: + ray_length += 1 + ni += di + nj += dj + + count += ray_length + + return count + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 (JSON格式的二维数组) + """ + if not test_solution: + return "" + # 尝试匹配Python代码块 + import re + code_block_patterns = [ + r'```python\s*\n(.*?\[.*?\].*?)\n```', # 标准Python代码块 + r'```\s*\n(.*?\[.*?\].*?)\n```', # 无语言标记的代码块 + r'```(.*?\[.*?\].*?)```' # 无换行的代码块 + ] + + for pattern in code_block_patterns: + matches = re.findall(pattern, test_solution, re.DOTALL) + if matches: + # 获取最后一个匹配项 + code_block = matches[-1].strip() + try: + # 尝试解析为Python列表 + grid = eval(code_block) + # 验证格式是否为二维数组 + if isinstance(grid, list) and all(isinstance(row, list) for row in grid): + return json.dumps(grid) + except Exception as e: + # print(f"解析代码块失败: {e}") + continue + + # 如果没有找到有效的代码块,尝试直接寻找列表 + list_pattern = r'\[\s*\[.*?\]\s*\]' + matches = re.findall(list_pattern, test_solution, re.DOTALL) + if matches: + try: + # 尝试解析为Python列表 + grid = eval(matches[-1]) + # 验证格式是否为二维数组 + if isinstance(grid, list) and all(isinstance(row, list) for row in grid): + return json.dumps(grid) + except Exception as e: + pass + # print(f"解析列表失败: {e}") + + # 如果上述方法都失败,返回空字符串 + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py new file mode 100644 index 000000000..8abdb0742 --- /dev/null +++ b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py @@ -0,0 +1,53 @@ +import re +from .data import Data +from .verifier import Verifier + +class BooleanExpressionsVerifier(Verifier): + """ + 验证器用于布尔表达式游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + # 提取所有字母(a-z和A-Z) + test_answer_letters = re.findall(r'[a-zA-Z]', test_answer) + ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer) + test_answer_letters = self.lower(test_answer_letters) + ground_truth_letters = self.lower(ground_truth_letters) + # 转换为集合进行比较 + test_set = set(test_answer_letters) + ground_truth_set = set(ground_truth_letters) + + return test_set == ground_truth_set + except Exception as e: + return False + + def lower(self, answer_list): + return [answer.lower() for answer in answer_list] + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/campsite_verifier.py b/verl/utils/reward_score/synlogic/campsite_verifier.py new file mode 100644 index 000000000..405b5b073 --- /dev/null +++ b/verl/utils/reward_score/synlogic/campsite_verifier.py @@ -0,0 +1,182 @@ +from .data import Data +from .verifier import Verifier +import re +import ast +from typing import List, Set, Tuple, Dict + + +class CampsiteVerifier(Verifier): + """ + Verifier for Campsite game + """ + def verify(self, data: Data, test_solution: str): + try: + test_answer = self.extract_answer(test_solution) + original_grid = data.metadata["grid"] + row_constraints = data.metadata["row_constraints"] + col_constraints = data.metadata["col_constraints"] + n = data.metadata["n"] + m = data.metadata["m"] + + if not test_answer: + return False + + if len(test_answer) != n or any(len(row) != m for row in test_answer): + return False + + if not self._check_trees_unchanged(original_grid, test_answer): + return False + + if not self._check_row_constraints(test_answer, row_constraints): + return False + + if not self._check_col_constraints(test_answer, col_constraints): + return False + + if not self._check_tents_not_adjacent(test_answer): + return False + + if not self._check_tent_tree_matching(test_answer): + return False + + return True + + except Exception as e: + return False + + def _extract_grid(self, test_answer: str) -> List[List[str]]: + """从回答中提取网格""" + grid_pattern = r'\[\s*\[.*?\]\s*\]' + match = re.search(grid_pattern, test_answer, re.DOTALL) + if match: + try: + grid_str = match.group(0) + return ast.literal_eval(grid_str) + except: + pass + + return None + + def _check_trees_unchanged(self, original_grid: List[List[str]], test_answer: List[List[str]]) -> bool: + """检查树木位置是否保持不变""" + for i in range(len(original_grid)): + for j in range(len(original_grid[0])): + if original_grid[i][j] == 'T' and test_answer[i][j] != 'T': + return False + if original_grid[i][j] != 'T' and test_answer[i][j] == 'T': + return False + return True + + def _check_row_constraints(self, grid: List[List[str]], row_constraints: List[int]) -> bool: + """检查行约束条件""" + for i in range(len(grid)): + tent_count = sum(1 for cell in grid[i] if cell == 'C') + if tent_count != row_constraints[i]: + return False + return True + + def _check_col_constraints(self, grid: List[List[str]], col_constraints: List[int]) -> bool: + """检查列约束条件""" + for j in range(len(grid[0])): + tent_count = sum(1 for i in range(len(grid)) if grid[i][j] == 'C') + if tent_count != col_constraints[j]: + return False + return True + + def _check_tents_not_adjacent(self, grid: List[List[str]]) -> bool: + """检查帐篷之间是否相邻(包括对角线)""" + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + + for i in range(n): + for j in range(m): + if grid[i][j] == 'C': + # 检查周围8个方向是否有其他帐篷 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == 'C': + return False + + return True + + def _check_tent_tree_matching(self, grid: List[List[str]]) -> bool: + """ + 检查帐篷与树木的一一匹配关系: + 1. 每个帐篷必须与一棵树正交相邻 + 2. 每棵树只能与一个帐篷匹配 + 3. 每个帐篷只能与一棵树匹配 + 4. 帐篷和树的数量必须相等 + """ + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + + tents = [] + trees = [] + for i in range(n): + for j in range(m): + if grid[i][j] == 'C': + tents.append((i, j)) + elif grid[i][j] == 'T': + trees.append((i, j)) + + if len(tents) != len(trees): + return False + + tent_to_trees = {} + tree_to_tents = {} + + for tent_i, tent_j in tents: + tent_to_trees[(tent_i, tent_j)] = [] + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + tree_i, tree_j = tent_i + di, tent_j + dj + if 0 <= tree_i < n and 0 <= tree_j < m and grid[tree_i][tree_j] == 'T': + tent_to_trees[(tent_i, tent_j)].append((tree_i, tree_j)) + + for tree_i, tree_j in trees: + tree_to_tents[(tree_i, tree_j)] = [] + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + tent_i, tent_j = tree_i + di, tree_j + dj + if 0 <= tent_i < n and 0 <= tent_j < m and grid[tent_i][tent_j] == 'C': + tree_to_tents[(tree_i, tree_j)].append((tent_i, tent_j)) + + for tent in tents: + if not tent_to_trees[tent]: + return False + + tent_matched = {} + tree_matched = {} + + def dfs(tent): + for tree in tent_to_trees[tent]: + if tree in visited: + continue + visited.add(tree) + + if tree not in tree_matched or dfs(tree_matched[tree]): + tent_matched[tent] = tree + tree_matched[tree] = tent + return True + return False + + for tent in tents: + visited = set() + if tent not in tent_matched: + if not dfs(tent): + return False + + return len(tent_matched) == len(tents) and len(tree_matched) == len(trees) + + def extract_answer(self, test_solution: str): + """从模型回答中提取解决方案""" + grid_pattern = r'\[\s*\[.*?\]\s*\]' + match = re.search(grid_pattern, test_solution, re.DOTALL) + if match: + try: + grid_str = match.group(0) + return ast.literal_eval(grid_str) + except: + pass + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/data.py b/verl/utils/reward_score/synlogic/data.py new file mode 100644 index 000000000..8f7d2998d --- /dev/null +++ b/verl/utils/reward_score/synlogic/data.py @@ -0,0 +1,51 @@ +import json + +class Data: + """ + Data class for game/corpus + @param question: question of the game/corpus + @param answer: answer of the game/corpus + @param difficulty: difficulty of the game/corpus, from 1 to 10 + """ + def __init__(self, question: str, answer: str, difficulty: int = 1, metadata: dict = None, **kwargs): + self.question = question + self.answer = answer + self.difficulty = difficulty + self.metadata = metadata + self.gpt_response = "" + + def to_json(self): + return { + "question": self.question, + "answer": self.answer, + "difficulty": self.difficulty, + "metadata": self.metadata, + "gpt_response": self.gpt_response + } + + def to_json_str(self): + return json.dumps(self.to_json(), ensure_ascii=False) + + @classmethod + def from_json_str(cls, json_str): + json_data = json.loads(json_str) + return cls(**json_data) + + @classmethod + def from_json_dict(cls, json_dict): + instance = cls(**json_dict) + if 'gpt_response' in json_dict: + instance.gpt_response = json_dict['gpt_response'] + return instance + + @classmethod + def from_jsonl_file(cls, file_path): + data_list = [] + with open(file_path, "r") as f: + for line in f: + json_data = json.loads(line) + instance = cls(**json_data) + if 'gpt_response' in json_data: + instance.gpt_response = json_data['gpt_response'] + data_list.append(instance) + return data_list \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py new file mode 100644 index 000000000..a498473e8 --- /dev/null +++ b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py @@ -0,0 +1,90 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageErrorsVerifier(Verifier): + """ + 验证器用于检查括号闭合错误识别游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + # 获取正确答案 + if data.metadata["is_valid"]: + correct_answer = "-1" # 合法序列对应-1 + else: + correct_answer = str(data.metadata["first_error_pos"]) + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 清理和标准化答案 + test_answer = test_answer.strip() + + # 检查-1答案(合法序列) + if correct_answer == "-1": + # 如果正确答案是-1(合法序列),只接受-1作为回答 + if test_answer == "-1": + is_correct = True + else: + is_correct = False + else: + # 正确答案是位置数字,需要验证模型回答也是相同数字 + try: + is_correct = (int(test_answer) == int(correct_answer)) + except (ValueError, TypeError): + # 如果模型回答不是有效数字,验证失败 + is_correct = False + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + answer_str = test_solution + if answer_str is None: + import re + # 清理回答文本 + solution = test_solution.strip() if test_solution else "" + + # 提取所有数字(包括负数) + numbers = re.findall(r'-?\d+', solution) + if numbers: + # 优先返回"-1"(如果存在) + if "-1" in numbers: + return "-1" + # 否则返回找到的第一个非负整数 + for num in numbers: + if num.isdigit() and int(num) >= 0: + return num + # 如果只有负数,返回第一个 + return numbers[0] + + # 检查是否表示合法 + + + # 默认返回空字符串 + return "" + elif any(keyword in answer_str.lower() for keyword in ["合法", "valid", "correct"]): + return "-1" + else: + return answer_str \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py new file mode 100644 index 000000000..ace952964 --- /dev/null +++ b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py @@ -0,0 +1,129 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageReasoningErrorsVerifier(Verifier): + """ + Dyck语言推理错误识别验证器 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的答案字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + # 获取元数据中的正确答案 + correct_indices = data.metadata["error_indices"] + # 格式化为正确的答案字符串格式 + expected_answer = self._format_answer(correct_indices) + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'") + + # 检查不明确的答案 + if "不确定" in test_answer or "不知道" in test_answer or "unclear" in test_answer.lower(): + # print("验证结果: 错误") + return False + + # 清理模型答案,允许一定的格式变化 + cleaned_test_answer = self._standardize_answer(test_answer) + + if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]): + # 如果没有错误,且模型回答是空字符串或表示无问题,则正确 + is_correct = True + else: + # 将两个答案转换为数字集合进行比较 + test_error_indices = self._extract_error_indices(cleaned_test_answer) + expected_error_indices = set(correct_indices) + + # 检查两个集合是否相同 + is_correct = test_error_indices == expected_error_indices + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + return False + + def _standardize_answer(self, answer: str) -> str: + """ + 标准化答案字符串 + + @param answer: 原始答案字符串 + @return: 标准化后的答案字符串 + """ + # 如果答案为空或仅包含空白字符 + if not answer or answer.strip() == "": + return "" + + # 如果答案表示没有错误 + if answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]: + return "" + + # 替换中文逗号为英文逗号 + answer = answer.replace(',', ',') + # 移除所有非数字和逗号的字符 + answer = re.sub(r'[^0-9,]', '', answer) + + return answer + + def _extract_error_indices(self, answer: str) -> set: + """ + 从答案字符串中提取错误索引集合 + + @param answer: 答案字符串 + @return: 错误索引集合 + """ + if not answer or answer.strip() == "": + return set() + + # 先替换中文逗号为英文逗号 + answer = answer.replace(',', ',') + + # 分割逗号分隔的数字字符串,并转换为整数集合 + try: + return set(int(idx) for idx in answer.split(',') if idx and idx.strip().isdigit()) + except ValueError: + return set() + + def _format_answer(self, error_indices: list) -> str: + """ + 格式化答案字符串 + + @param error_indices: 错误步骤的索引列表 + @return: 格式化的答案字符串 + """ + if not error_indices: + return "" + + # 按照数字大小排序 + sorted_indices = sorted(error_indices) + + # 转换为字符串,使用英文逗号分隔 + return ",".join(map(str, sorted_indices)) + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取错误索引 + + @param test_solution: 模型的完整回答 + """ + answer_str = test_solution + return answer_str + +if __name__ == "__main__": + test_response = '''\nOkay, so I'm trying to figure out if there are any errors in the reasoning process provided for checking the proper closure of brackets in the sequence: {}[][][]{}{}([]). Let me go through each thought step by step and see where things might have gone wrong.\n\nStarting with Thought 2: the initial stack is empty, which makes sense because we haven't processed any characters yet. So that seems correct.\n\nThought 3: The first character is '{', so it's pushed onto the stack. The stack becomes '{'. Hmm, but in the thought, it's written as '{[', which is incorrect. The stack after the first character should only have '{'. So this thought is wrong because it adds an extra '[' that wasn't there yet. Wait, no, actually, the input sequence starts with '{', so maybe there's a mistake in the thought. Let me recheck. Oh, looking again, no, the input sequence is {}[][][]{}{}([)]}, so the first two characters are '{' and '}'. So after processing '{', the stack should be '{', not '{['. So Thought 3 is wrong because it mistakenly pushes '[' instead of just '{'.\n\nMoving to Thought 4: Next character is '}' which correctly matches the top '{' in the stack, so the stack becomes empty. This part is correct.\n\nThought 5: The next character is '[', so stack becomes '['. Correct.\n\nThought 6: Next is ']', which pops the '[', so stack is empty. Correct.\n\nThought 7: Next is '[', stack becomes '['. Correct.\n\nThought 8: Next is ']', so stack is empty. Correct.\n\nThought 9: Next is '[', stack becomes '['. Correct.\n\nThought 10: Next is ']' which pops the '[', so stack should be empty again. However, in the thought, it says the stack becomes ']', which is incorrect. So Thought 10 is wrong because it added a ']' instead of popping it off. Wait no, the character is ']', which should pop the top '[', so the stack should go from '[' to empty. But the thought says the stack becomes ']' which is wrong. So Thought 10 is incorrect.\n\nThought 11: Next is '{', stack is now '{'. Correct.\n\nThought 12: Next is '}', popping the '{', stack is empty. Correct.\n\nThought 13: Next is '{', stack becomes '{'. Correct.\n\nThought 14: Next is '}', stack is empty. Correct.\n\nThought 15: Next is '(', stack becomes '('. Correct.\n\nThought 16: Next is '[', so stack becomes '(['. Correct.\n\nThought 17: Next is ']', which pops the '[', so stack becomes '('. Correct.\n\nThought 18: Next is ')', which pops the '(', so stack is empty. Correct.\n\nThought 19: After processing all characters, stack is empty. So the sequence is balanced.\n\nWait, but the input sequence actually ends with ']})', right? The sequence given is {}[][][]{}{}([]). So after processing '(' and '[', we have '(', then '['. Then when we process ']', it pops the '[', stack becomes '('. Then next is ')', which pops '(', stack is empty. So everything is correct up to that point.\n\nBut in the earlier steps, the issue was in Thought 10. Because when processing the fifth character, which is '[', followed by ']', the stack should go back to empty. So Thought 10 says the stack becomes ']', which is incorrect. It should go back to empty.\n\nWait, let me recount the steps:\n\nThe input is {}[][][]{}{}([]). So let's break it down:\n\n1. '{' : stack becomes '{'\n2. '}' : stack empty\n3. '[' : stack becomes '['\n4. ']' : stack empty\n5. '[' : '['\n6. ']' : stack empty\n7. '[' : '['\n8. ']' : stack empty\n9. '{' : '{'\n10. '}' : stack empty\n11. '{' : '{'\n12. '}' : stack empty\n13. '(' : '('\n14. '[' : '(['\n15. ']' : '('\n16. ')' : stack empty.\n\nWait, but in the original thought process, after step 10, the character is ']' which pops the '[', making the stack empty, but the thought says stack is ']'. So that's definitely incorrect.\n\nAlso, in step 3, the thought says after '{' the stack is '{[', which is incorrect because the next character is '}', not another '['. So step 3 is wrong because it added an extra '[' instead of just '{' and then processing '}' in step 4 correctly.\n\nSo the errors are in Thought 3 and Thought 10.\n\n\nThe errors are in Thought 3 and Thought 10 because Thought 3 incorrectly adds an extra '[' instead of just '{' and Thought 10 mistakenly leaves the stack as ']' instead of empty after popping.\n\n[3,20]''' + metadata = {"trace_id": "77db72eb-a9db-46cd-96ea-5a49eba78792", "dyck_sequence": "{}[][][]{}{}([])", "thoughts": ["Thought 1: 我们应该逐个处理输入并跟踪栈的配置。", "Thought 2: 栈: 空", "Thought 3: { ; 栈: {[", "Thought 4: } ; 栈: 空", "Thought 5: [ ; 栈: [", "Thought 6: ] ; 栈: 空", "Thought 7: [ ; 栈: [", "Thought 8: ] ; 栈: 空", "Thought 9: [ ; 栈: [", "Thought 10: ] ; 栈: ]", "Thought 11: { ; 栈: {", "Thought 12: } ; 栈: 空", "Thought 13: { ; 栈: {", "Thought 14: } ; 栈: 空", "Thought 15: ( ; 栈: (", "Thought 16: [ ; 栈: ([", "Thought 17: ] ; 栈: (", "Thought 18: ) ; 栈: 空", "Thought 19: 现在,我们已经到达结尾。最终栈是空的。"], "error_indices": [3, 10], "n_types": 3, "total_length": 15, "n_errors": 2} + test_data = Data(question="", answer="", metadata=metadata) + test_verifier = DyckLanguageReasoningErrorsVerifier() + extracted_answer = test_verifier.extract_answer(test_response) + print(extracted_answer) + print(test_verifier.verify(data=test_data, test_answer=test_response)) \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/dyck_language_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_verifier.py new file mode 100644 index 000000000..04f3b2d20 --- /dev/null +++ b/verl/utils/reward_score/synlogic/dyck_language_verifier.py @@ -0,0 +1,81 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageVerifier(Verifier): + """ + 验证器用于检查Dyck Language游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str) -> bool: + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + # 获取元数据中的完整序列 + full_sequence = data.metadata["full_sequence"] + + # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'") + + # 从模型回答中提取答案 + extracted_answer = self.extract_answer(test_answer) + + # 检查答案是否完全匹配 + is_correct = (extracted_answer == full_sequence) + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + return False + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取括号序列答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + if not test_solution: + return "" + + # print(f"原始回答:\n{test_solution}") + + def clean_text(text: str) -> str: + """清理文本,处理转义字符和空白字符""" + # 移除所有空白字符(包括换行符、制表符等) + text = ''.join(text.split()) + + # 处理转义序列 + text = text.replace('\\n', '') + text = text.replace('\\t', '') + text = text.replace('\\r', '') + text = text.replace('\\\\', '\\') + + # 如果文本被引号包围,且引号不是括号序列的一部分,则移除外层引号 + if len(text) >= 2: + if text.startswith('"') and text.endswith('"'): + text = text[1:-1] + elif text.startswith("'") and text.endswith("'"): + text = text[1:-1] + + return text + + return clean_text(test_solution) + +if __name__ == "__main__": + test_response = '''填写后的完整序列应为“([])({})([()])”。\n\n检查一下长度是否正确:\n\n原序列长度为11字符,补充3个字符,总长度14。\n\n这样,整个序列是合法的。\n\n\n([])({})([()])''' + metadata = {"trace_id": "38aeede4-d5d7-4863-91d2-df1fd99f491b", "full_sequence": "([])({})([()])", "question_sequence": "([])({})([(", "n_types": 3, "total_length": 14, "fill_length": 3, "nesting_depth": 0} + test_data = Data(question="", answer="", metadata=metadata) + test_verifier = DyckLanguageVerifier() + extracted_answer = test_verifier.extract_answer(test_response) + print(extracted_answer) + print(test_verifier.verify(data=test_data, test_answer=test_response)) \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py new file mode 100644 index 000000000..b4c8bdb25 --- /dev/null +++ b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py @@ -0,0 +1,126 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + +class BuggyTableVerifier(Verifier): + """ + Verifier for the BuggyTable game. + Checks if the submitted answer matches the expected answer. + """ + def extract_answer(self, answer: str) -> str: + """ + Public method to extract and normalize an answer string from LLM output. + Delegates to the private _extract_answer method. + + @param answer: The answer string to normalize + @return: The normalized answer string + """ + return self._extract_answer(answer) + + def verify(self, data: Data, test_answer: str) -> bool: + """ + Verify whether the test answer is consistent with the expected answer + for the buggy table query. + + @param data: Data object containing the expected answer + @param test_answer: The answer provided by the LLM to verify + @return: bool indicating whether the answer is correct + """ + # Extract the expected answer from the Data object + expected_answer = data.answer if data and hasattr(data, 'answer') else "" + + # For empty strings, compare directly + if not expected_answer and not test_answer: + return True + + # Extract and normalize both answers + normalized_expected = self._extract_answer(expected_answer) + normalized_test = self._extract_answer(test_answer) + + # Direct comparison of normalized answers + return normalized_expected == normalized_test + + def _is_raw_numeric_answer(self, value: str) -> bool: + """ + Check if a string represents a plain numeric answer without additional context. + This is used to validate raw input format. + + @param value: The string to check + @return: True if the string is a simple numeric value + """ + # Remove whitespace + value = value.strip() + + # Simple pattern match for a number (optionally with sign and decimal point) + import re + return bool(re.match(r'^-?\d+(\.\d+)?$', value)) + + def _raw_has_exactly_two_decimals(self, value: str) -> bool: + """ + Check if a raw numeric string has exactly 2 decimal places. + This is used to validate the format of the raw answer. + + @param value: The string to check + @return: True if the string has exactly 2 decimal places + """ + # Remove whitespace + value = value.strip() + + # Split on decimal point + parts = value.replace('-', '', 1).split('.') + + # Check if there is exactly one decimal point and two digits after it + return len(parts) == 2 and len(parts[1]) == 2 + + def _is_numeric(self, value: str) -> bool: + """ + Check if a string represents a valid number (including negative numbers and decimals). + + @param value: The string to check + @return: True if the string represents a valid number + """ + # Remove negative sign if present + value = value.strip() + if value.startswith('-'): + value = value[1:] + # Check if remaining string is a valid decimal number + return value.replace('.', '', 1).isdigit() + + def _has_exactly_two_decimals(self, value: str) -> bool: + """ + Check if a number string has exactly 2 decimal places. + + @param value: The number string to check + @return: True if the number has exactly 2 decimal places + """ + # Remove negative sign if present + value = value.strip() + if value.startswith('-'): + value = value[1:] + + # Split into whole and decimal parts + parts = value.split('.') + if len(parts) != 2: + return False + + # Check if decimal part has exactly 2 digits + return len(parts[1]) == 2 + + def _extract_answer(self, answer: str) -> str: + """ + Extract and normalize an answer string from LLM output. + Only finds values with exactly two decimal places. + + @param answer: The answer string to normalize + @return: The normalized answer string + """ + # Convert to string and normalize + normalized = str(answer).strip() if answer is not None else "" + + # Try to find numbers with exactly two decimal places + exact_matches = re.findall(r'-?\d+\.\d{2}\b', normalized) + if exact_matches: + return exact_matches[-1] # Return the last match with exactly two decimals + + # If no exact two-decimal match found, return the original string + return normalized \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py new file mode 100644 index 000000000..1a57d66b8 --- /dev/null +++ b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py @@ -0,0 +1,216 @@ +import re +from .data import Data +from .verifier import Verifier + +class GoodsExchangeVerifier(Verifier): + """ + 验证器用于检查物品交换游戏的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 获取元数据中的正确答案 + correct_answer = data.metadata["owns_after"] + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 解析模型答案 + model_ownership = self._parse_answer(test_answer) + # 解析正确答案 + correct_ownership = self._parse_answer(correct_answer) + + # 比较两个答案是否完全一致 + is_correct = self._compare_answers(model_ownership, correct_ownership) + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + # # 打印详细的不匹配信息 + # self._print_difference(model_ownership, correct_ownership) + + return is_correct + + except Exception as e: + return False + + def _parse_answer(self, answer_str): + """ + 解析答案字符串为物品归属字典 + + @param answer_str: 答案字符串,格式为"(('人1','物品1'),('人2','物品2'),...)"或"(人1,物品1),(人2,物品2),..." + @return: 归属关系字典 {人: 物品} + """ + if not answer_str: + return {} + + result = {} + try: + # 预处理:只处理最外层的空格,保留内部结构 + answer_str = answer_str.strip() + + # 尝试使用 eval 解析 Python tuple 格式 + pairs = eval(answer_str) + if isinstance(pairs, tuple): + for pair in pairs: + if isinstance(pair, tuple) and len(pair) == 2: + person, item = pair + # 处理每个值中的空格:移除两端空格 + result[person.strip()] = item.strip() + return result + except Exception as e: + # 如果 eval 失败,记录错误并尝试解析旧格式 + + # 移除最外层的括号(如果有) + if answer_str.startswith('('): + answer_str = answer_str[1:] + if answer_str.endswith(')'): + answer_str = answer_str[:-1] + + # 更健壮的手动解析逻辑 + person_item_pairs = [] + current_pair = "" + bracket_count = 0 + + # 更智能地分割答案字符串 + for char in answer_str: + if char == '(': + bracket_count += 1 + current_pair += char + elif char == ')': + bracket_count -= 1 + current_pair += char + if bracket_count == 0: + person_item_pairs.append(current_pair) + current_pair = "" + elif char == ',' and bracket_count == 0: + # 跳过顶层逗号 + continue + else: + current_pair += char + + # 处理每一对 + for pair in person_item_pairs: + pair = pair.strip() + # 移除括号 + if pair.startswith('('): + pair = pair[1:] + if pair.endswith(')'): + pair = pair[:-1] + + # 拆分人和物品 + try: + # 使用更健壮的分割方法 + parts = [] + quote_count = 0 + current = "" + + for char in pair: + if char in "\"'" and (len(current) == 0 or current[-1] != '\\'): + quote_count = 1 - quote_count + + if char == ',' and quote_count == 0: + parts.append(current.strip()) + current = "" + else: + current += char + + if current: + parts.append(current.strip()) + + if len(parts) >= 2: + person = parts[0].strip().strip("'\"") + item = parts[1].strip().strip("'\"") + result[person] = item + except Exception as e: + print(f"NOTE!!! parse error!!!! (GoodsExchange 2): {e}") + + return result + + def _compare_answers(self, model_ownership, correct_ownership): + """ + 比较两个归属关系字典是否相同 + + @param model_ownership: 模型回答的归属关系 + @param correct_ownership: 正确的归属关系 + @return: 是否完全一致 + """ + # 检查人数是否相同 + if len(model_ownership) != len(correct_ownership): + return False + + # 创建小写人名到原始人名的映射 + model_lower_to_original = {person.lower(): person for person in model_ownership} + + # 检查每个人的物品是否一致 + for person in correct_ownership: + # 如果模型答案中没有这个人(不区分大小写) + if person.lower() not in model_lower_to_original: + return False + + # 获取模型答案中对应的原始人名 + model_person = model_lower_to_original[person.lower()] + + # 如果人的物品不匹配(不区分大小写) + if model_ownership[model_person].lower() != correct_ownership[person].lower(): + return False + + return True + + def _print_difference(self, model_ownership, correct_ownership): + """ + 打印两个归属关系之间的差异 + + @param model_ownership: 模型回答的归属关系 + @param correct_ownership: 正确的归属关系 + """ + print("\n差异详情:") + + # 创建小写人名到原始人名的映射 + model_lower_to_original = {person.lower(): person for person in model_ownership} + correct_lower_to_original = {person.lower(): person for person in correct_ownership} + + # 检查正确答案中的每个人 + for person in correct_ownership: + person_lower = person.lower() + if person_lower not in model_lower_to_original: + # print(f" - 模型答案中缺少: {person}") + pass + else: + model_person = model_lower_to_original[person_lower] + # if model_ownership[model_person].lower() != correct_ownership[person].lower(): + # print(f" - {person}: 模型答案={model_ownership[model_person]}, 正确答案={correct_ownership[person]}") + + # 检查模型答案中的额外人员 + # for person in model_ownership: + # if person.lower() not in correct_lower_to_original: + # print(f" - 模型答案中多余: {person}") + + def extract_answer(self, text): + """从文本中提取答案。 + + Args: + text (str): 输入文本 + + Returns: + str: 提取的答案,格式为 "(('人1','物品1'),('人2','物品2'),...)" + """ + if not text: + return "" + + # 尝试从 Python markdown 代码块中提取 + code_block_pattern = r'```python\s*\n(.*?)\n```' + code_blocks = re.findall(code_block_pattern, text, re.DOTALL) + if code_blocks: + # 使用最后一个代码块 + last_block = code_blocks[-1].strip() + if last_block.startswith("(") and last_block.endswith(")"): + return last_block + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/math_path_verifier.py b/verl/utils/reward_score/synlogic/math_path_verifier.py new file mode 100755 index 000000000..deffc7a28 --- /dev/null +++ b/verl/utils/reward_score/synlogic/math_path_verifier.py @@ -0,0 +1,98 @@ +import re +import json +import numpy as np +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class MathPathVerifier(Verifier): + """ + 验证器用于检查math_path填充游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的运算表达式 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + except Exception as e: + return False + + try: + # 解析元数据 + metadata = data.metadata + ref_expr = metadata["ref_expr"] + query_expr = metadata["query_expr"] + + # 验证数字是否被篡改,数字是否在0-9之间。 + test_tmp = test_answer.replace(' ', '').strip() + query_tmp = query_expr.replace(' ', '').strip() + ref_tmp = ref_expr.replace(' ', '').strip() + query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?'] + test_nums = [x for x in test_tmp if '0'<=x<='9'] + if len(query_nums)!=len(test_nums): + # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + else: + for ind, x in enumerate(query_nums): + if x=='?': + continue + if x!=test_nums[ind]: + # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']] + test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']] + if len(query_symbols)!=len(test_symbols): + # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + else: + for ind, x in enumerate(query_symbols): + if x!=test_symbols[ind]: + # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + # 验证回答中的等式是否成立 + try: + tmp = test_tmp.replace('=', '==') + if not eval(tmp): + # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + except: + # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + + # 所有检查都通过 + # print("验证结果: 正确") + return True + + except Exception as e: + return False + + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案(字符表达式) + + @param test_solution: 模型的完整回答 + @return: 提取的矩阵答案字符串 + """ + if not test_solution: + return "" + # 尝试提取Python代码块中的矩阵 + code_block_pattern = r'\[\[(.*?)\]\]' + code_matches = re.findall(code_block_pattern, test_solution) + + if code_matches: + # 使用最后一个匹配内容 + operation_expression = code_matches[-1].strip() + return operation_expression + + # 如果所有方法都失败,返回空字符串 + return "" + diff --git a/verl/utils/reward_score/synlogic/minesweeper_verifier.py b/verl/utils/reward_score/synlogic/minesweeper_verifier.py new file mode 100644 index 000000000..ca73d4fe3 --- /dev/null +++ b/verl/utils/reward_score/synlogic/minesweeper_verifier.py @@ -0,0 +1,60 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +from typing import List, Tuple + + +class MinesweeperVerifier(Verifier): + """ + Verifier for Minesweeper puzzle + 扫雷游戏验证器 + """ + def verify(self, data: Data, test_solution: str, **kwargs): + try: + # 从解答中提取地雷坐标 + predicted_mines = self.extract_answer(test_solution) + + # 从metadata中获取确定性地雷坐标 + expected_mines = data.metadata["current_mines"] + + # 验证提取的坐标是否正确 + if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines): + return True + + return False + + except Exception as e: + # 如果验证过程中发生任何错误,返回False + return False + + def extract_answer(self, response: str) -> List[Tuple[int, int]]: + """从模型的响应中提取地雷坐标 + Extract mine coordinates from the model's response""" + patterns = [ + r'\[\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*\s*\]', # [(0,1),(2,3)] + r'\[\s*\[\s*(\d+)\s*,\s*(\d+)\s*\](?:\s*,\s*\[\s*\d+\s*,\s*\d+\s*\])*\s*\]', # [[0,1],[2,3]] + r'\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*', # (0,1),(2,3) + ] + + for pattern in patterns: + coords = [] + for match in re.finditer(pattern, response): + try: + # 提取所有坐标对 + coord_pattern = r'(?:\(|\[)\s*(\d+)\s*,\s*(\d+)\s*(?:\)|\])' + for coord_match in re.finditer(coord_pattern, match.group(0)): + i, j = int(coord_match.group(1)), int(coord_match.group(2)) + coords.append((i, j)) + except Exception: + continue + + if coords: + return coords + + # 如果没有找到坐标,尝试查找可能是坐标的任何数字 + number_pairs = re.findall(r'(\d+)[^\d]+(\d+)', response) + if number_pairs: + return [(int(i), int(j)) for i, j in number_pairs] + + return [] \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/norinori_verifier.py b/verl/utils/reward_score/synlogic/norinori_verifier.py new file mode 100644 index 000000000..6d27d21a8 --- /dev/null +++ b/verl/utils/reward_score/synlogic/norinori_verifier.py @@ -0,0 +1,186 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +from collections import defaultdict + +class NorinoriVerifier(Verifier): + """ + Norinori 游戏的验证器 + 检查提交的答案是否符合 Norinori 游戏规则 + """ + + def __init__(self): + super().__init__() + + def verify(self, data: Data, test_solution: str): + """ + 验证 Norinori 游戏的答案 + + 参数: + data -- 游戏数据,包含区域网格等信息 + test_solution -- 用户提交的答案,应为多米诺坐标列表 + + 返回: + bool -- 答案是否正确 + """ + try: + # 从游戏数据中获取区域网格 + region_grid = data.metadata["region_grid"] + n = len(region_grid) + + # 解析答案 + dominoes = self._parse_answer(test_solution) + if dominoes is None: + return False + + # 检查多米诺形状 + if not self._check_domino_shapes(dominoes): + return False + + # 创建覆盖网格 + covered = [[False for _ in range(n)] for _ in range(n)] + for domino in dominoes: + for i, j in domino: + # 转换为0-indexed + i -= 1 + j -= 1 + if i < 0 or i >= n or j < 0 or j >= n: + return False # 坐标超出范围 + if covered[i][j]: + return False # 格子被多次覆盖 + covered[i][j] = True + + # 检查多米诺之间是否相邻 + if not self._check_domino_adjacency(dominoes, n): + return False + + # 检查每个区域是否恰好有两个格子被覆盖 + region_coverage = defaultdict(int) + for i in range(n): + for j in range(n): + if covered[i][j] and region_grid[i][j] != "X": + region_coverage[region_grid[i][j]] += 1 + + for region, count in region_coverage.items(): + if count != 2: + return False + + # 检查所有阴影格子是否被覆盖 + for i in range(n): + for j in range(n): + if region_grid[i][j] == "X" and not covered[i][j]: + return False + + return True + except Exception as e: + return False + + def _parse_answer(self, test_solution: str): + """ + 解析答案字符串,提取多米诺坐标 + + 参数: + test_solution -- 答案字符串 + + 返回: + list -- 多米诺坐标列表,如果格式不正确则返回None + """ + try: + # 使用正则表达式提取坐标对 + pattern = r'\[\((\d+),\s*(\d+)\),\s*\((\d+),\s*(\d+)\)\]' + matches = re.findall(pattern, test_solution) + + if not matches: + # 尝试另一种可能的格式 + pattern = r'\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)' + matches = re.findall(pattern, test_solution) + + dominoes = [] + for match in matches: + i1, j1, i2, j2 = map(int, match) + dominoes.append([(i1, j1), (i2, j2)]) + + return dominoes + except Exception as e: + return None + + def _check_domino_shapes(self, dominoes): + """ + 检查所有多米诺是否都是1×2或2×1的形状 + + 参数: + dominoes -- 多米诺坐标列表 + + 返回: + bool -- 是否所有多米诺都符合形状要求 + """ + for domino in dominoes: + if len(domino) != 2: + return False + + (i1, j1), (i2, j2) = domino + + # 检查是否为1×2或2×1 + if not ((i1 == i2 and abs(j1 - j2) == 1) or + (j1 == j2 and abs(i1 - i2) == 1)): + return False + + return True + + def _check_domino_adjacency(self, dominoes, n): + """ + 检查多米诺之间是否相邻 + + 参数: + dominoes -- 多米诺坐标列表 + n -- 网格大小 + + 返回: + bool -- 是否所有多米诺都不相邻 + """ + # 创建一个网格来标记每个多米诺的位置 + grid = [[-1 for _ in range(n+2)] for _ in range(n+2)] # 加2是为了处理边界 + + for idx, domino in enumerate(dominoes): + for i, j in domino: + # 转换为0-indexed并考虑边界 + grid[i][j] = idx + + # 检查每个多米诺是否与其他多米诺相邻 + for idx, domino in enumerate(dominoes): + for i, j in domino: + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + ni, nj = i + di, j + dj + if 1 <= ni <= n and 1 <= nj <= n: # 检查是否在网格内 + if grid[ni][nj] != -1 and grid[ni][nj] != idx: + return False # 发现相邻的多米诺 + + return True + + def extract_answer(self, test_solution: str, strict=False): + """ + 从回答中提取答案 + + 参数: + test_solution -- 用户的回答 + strict -- 是否严格模式 + + 返回: + str -- 提取的答案 + """ + # 尝试找到答案部分 + answer_patterns = [ + r'\[\s*\[\s*\(\s*\d+\s*,\s*\d+\s*\)\s*,\s*\(\s*\d+\s*,\s*\d+\s*\)\s*\]', # 寻找格式如 [[(1,2), (1,3)], ...] 的答案 + r'答案是\s*(.*?)\s*$', # 中文格式 + r'answer is\s*(.*?)\s*$', # 英文格式 + r'solution is\s*(.*?)\s*$' # 另一种英文格式 + ] + + for pattern in answer_patterns: + matches = re.findall(pattern, test_solution, re.IGNORECASE | re.DOTALL) + if matches: + # 返回最后一个匹配项,通常是最终答案 + return matches[-1] + + # 如果没有找到明确的答案格式,返回整个解答 + return test_solution \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/number_wall_verifier.py b/verl/utils/reward_score/synlogic/number_wall_verifier.py new file mode 100644 index 000000000..3336d292a --- /dev/null +++ b/verl/utils/reward_score/synlogic/number_wall_verifier.py @@ -0,0 +1,225 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +from collections import deque + +class NumberWallVerifier(Verifier): + """ + Verifier for Number Wall puzzle + 数字墙拼图验证器 + """ + def verify(self, data: Data, test_solution: str, **kwargs): + try: + # 提取答案网格 + solution_grid = self.extract_answer(test_solution) + if not solution_grid: + # print("Failed to extract solution grid") + return False + + # 提取元数据 + original_grid = data.metadata["grid"] + n = data.metadata["n"] + + # 检查网格尺寸 + if len(solution_grid) != n: + # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}") + return False + + for row in solution_grid: + if len(row) != n: + # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}") + return False + + # 检查每个单元格只包含数字、"X"或"A" + for cell in row: + if not (isinstance(cell, int) or cell in ["X", "A"]): + # print(f"Invalid cell content: {cell}") + return False + + # 检查原始数字是否保留 + if not self._check_original_numbers(original_grid, solution_grid): + # print("Original numbers not preserved") + return False + + # 检查墙壁布局是否有效(没有2×2或更大的连续墙块) + if not self._check_wall_layout(solution_grid): + # print("Invalid wall layout (2x2 or larger continuous wall blocks found)") + return False + + # 检查岛屿划分是否有效 + if not self._check_islands(solution_grid): + # print("Invalid island division") + return False + + # 检查是否有斜线边 + if not self._check_diagonal_borders(solution_grid): + # print("Invalid solution: islands have diagonal borders") + return False + + return True + + except Exception as e: + # 如果验证过程中发生任何错误,返回False + return False + + def _check_original_numbers(self, original_grid, solution_grid): + """检查原始数字是否在解决方案中保留""" + for i in range(len(original_grid)): + for j in range(len(original_grid[i])): + if isinstance(original_grid[i][j], int): + if original_grid[i][j] != solution_grid[i][j]: + # print(f"Original number at ({i},{j}) changed: {original_grid[i][j]} -> {solution_grid[i][j]}") + return False + return True + + def _check_wall_layout(self, grid): + """检查墙壁布局是否有效(没有2×2或更大的连续墙块)""" + n = len(grid) + for i in range(n - 1): + for j in range(n - 1): + if (grid[i][j] == "A" and grid[i][j+1] == "A" and + grid[i+1][j] == "A" and grid[i+1][j+1] == "A"): + # print(f"Found 2x2 wall block at ({i},{j})") + return False + return True + + def _check_islands(self, grid): + """检查岛屿划分是否有效""" + n = len(grid) + visited = set() + + for i in range(n): + for j in range(n): + if (i, j) not in visited and grid[i][j] != "A": + # 发现一个新岛屿 + island_cells = [] + island_number = None + queue = deque([(i, j)]) + visited.add((i, j)) + + while queue: + r, c = queue.popleft() + island_cells.append((r, c)) + + if isinstance(grid[r][c], int): + if island_number is not None: + # 岛屿有多个数字 + # print(f"Island contains multiple numbers: {island_number} and {grid[r][c]}") + return False + island_number = grid[r][c] + + for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + nr, nc = r + dr, c + dc + if (0 <= nr < n and 0 <= nc < n and + (nr, nc) not in visited and + grid[nr][nc] != "A"): + queue.append((nr, nc)) + visited.add((nr, nc)) + + if island_number is None: + # 岛屿没有数字 + # print(f"Island at ({i},{j}) has no number") + return False + + if len(island_cells) != island_number: + # 岛屿大小与数字不匹配 + # print(f"Island size ({len(island_cells)}) doesn't match number ({island_number})") + return False + + return True + + def _check_diagonal_borders(self, grid): + """检查是否有斜线边(对角相邻的不同岛屿)""" + n = len(grid) + + # 标记所有岛屿 + island_map = {} # 映射格子坐标到岛屿ID + island_id = 0 + visited = set() + + for i in range(n): + for j in range(n): + if grid[i][j] != "A" and (i, j) not in visited: + # 发现一个新岛屿 + queue = deque([(i, j)]) + visited.add((i, j)) + + while queue: + r, c = queue.popleft() + island_map[(r, c)] = island_id + + for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + nr, nc = r + dr, c + dc + if (0 <= nr < n and 0 <= nc < n and + grid[nr][nc] != "A" and (nr, nc) not in visited): + queue.append((nr, nc)) + visited.add((nr, nc)) + + island_id += 1 + + # 检查斜线边 + for i in range(n - 1): + for j in range(n - 1): + # 检查2x2方格中的对角格子 + if (grid[i][j] != "A" and grid[i+1][j+1] != "A" and + grid[i][j+1] == "A" and grid[i+1][j] == "A"): + # 对角格子属于不同岛屿,存在斜线边 + if island_map.get((i, j)) != island_map.get((i+1, j+1)): + # print(f"Found diagonal border at ({i},{j}) and ({i+1},{j+1})") + return False + + # 检查另一对对角格子 + if (grid[i][j+1] != "A" and grid[i+1][j] != "A" and + grid[i][j] == "A" and grid[i+1][j+1] == "A"): + # 对角格子属于不同岛屿,存在斜线边 + if island_map.get((i, j+1)) != island_map.get((i+1, j)): + # print(f"Found diagonal border at ({i},{j+1}) and ({i+1},{j})") + return False + + return True + + + def extract_answer(self, response: str): + """从模型的响应中提取答案网格""" + # 在响应中寻找网格表示 + # 修改正则表达式以匹配字符串形式的数字 + grid_pattern = r'\[\s*\[(?:\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*,\s*)*\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*\]\s*(?:,\s*\[(?:\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*,\s*)*\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*\]\s*)*\]' + matches = re.findall(grid_pattern, response) + + if matches: + # 尝试解析最后一个匹配项 + grid_str = matches[-1] + + try: + # 尝试清理字符串,替换可能导致问题的字符 + cleaned_grid_str = grid_str.replace('\n', '').replace('\r', '').strip() + grid = json.loads(cleaned_grid_str) + + # 将字符串数字转换为整数 + for i in range(len(grid)): + for j in range(len(grid[i])): + if isinstance(grid[i][j], str) and grid[i][j].isdigit(): + grid[i][j] = int(grid[i][j]) + + return grid + except json.JSONDecodeError as e: + # 尝试使用 ast.literal_eval 作为备选方案 + try: + import ast + grid = ast.literal_eval(cleaned_grid_str) + + # 将字符串数字转换为整数 + for i in range(len(grid)): + for j in range(len(grid[i])): + if isinstance(grid[i][j], str) and grid[i][j].isdigit(): + grid[i][j] = int(grid[i][j]) + + return grid + except Exception as e2: + pass + else: + # print("No grid pattern found in the response") + pass + + return None diff --git a/verl/utils/reward_score/synlogic/numbrix_verifier.py b/verl/utils/reward_score/synlogic/numbrix_verifier.py new file mode 100644 index 000000000..13a29494a --- /dev/null +++ b/verl/utils/reward_score/synlogic/numbrix_verifier.py @@ -0,0 +1,101 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import ast +import numpy as np + +class NumbrixVerifier(Verifier): + """ + Numbrix 游戏的验证器 + 验证提交的解答是否符合 Numbrix 游戏规则 + """ + def verify(self, data: Data, test_solution: str): + try: + # 提取答案网格 + test_grid = self.extract_answer(test_solution) + if not test_grid: + return False + + # 获取原始谜题和网格大小 + original_grid = data.metadata["grid"] + n = len(original_grid) + n_squared = n * n + + # 检查网格大小是否正确 + if len(test_grid) != n or any(len(row) != n for row in test_grid): + return False + + # 检查是否包含所有数字 1 到 n² + flattened_grid = [cell for row in test_grid for cell in row] + if sorted(flattened_grid) != list(range(1, n_squared + 1)): + return False + + # 检查是否保留了原始提示数字 + for i in range(n): + for j in range(n): + if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]: + return False + + # 检查连续数字是否正交相邻 + for num in range(1, n_squared): + # 找到当前数字的位置 + current_pos = None + next_pos = None + for i in range(n): + for j in range(n): + if test_grid[i][j] == num: + current_pos = (i, j) + elif test_grid[i][j] == num + 1: + next_pos = (i, j) + + if current_pos is None or next_pos is None: + return False + + # 检查是否正交相邻(曼哈顿距离为1) + i1, j1 = current_pos + i2, j2 = next_pos + manhattan_distance = abs(i1 - i2) + abs(j1 - j2) + if manhattan_distance != 1: + return False + + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str, strict=False): + """从模型回答中提取网格""" + try: + import ast + import re + # 尝试找到 Python 列表格式的答案 + # 寻找形如 [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 的模式 + pattern = r'\[\s*\[\s*\d+.*?\]\s*\]' + matches = re.finditer(pattern, test_solution, re.DOTALL) + match = None + + # 获取最后一个匹配项 + for m in matches: + match = m + if not match: + return None + + # 提取匹配的文本并尝试解析为 Python 对象 + grid_text = match.group(0) + + # 清理文本,确保它是有效的 Python 列表 + # 移除可能导致解析错误的字符 + grid_text = grid_text.replace("'", "").replace('"', "") + + # 解析为 Python 对象 + grid = ast.literal_eval(grid_text) + + # 确保是二维列表且所有元素都是整数 + if not isinstance(grid, list) or not all(isinstance(row, list) for row in grid): + return None + + if not all(isinstance(cell, int) for row in grid for cell in row): + return None + + return grid + except Exception as e: + return None \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/object_counting_verifier.py b/verl/utils/reward_score/synlogic/object_counting_verifier.py new file mode 100644 index 000000000..389cafce2 --- /dev/null +++ b/verl/utils/reward_score/synlogic/object_counting_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectCountingVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = self.extract_answer(test_answer) + with open("solution_str_OC.txt", "a") as f: + f.write("data.answer: " + data.answer + '\n') + f.write("test_answer: " + test_answer + '\n') + f.write("parsed_answer" + parsed_answer + '\n') + f.write('-'*32 + '\n') + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return answer_str + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/object_properties_verifier.py b/verl/utils/reward_score/synlogic/object_properties_verifier.py new file mode 100644 index 000000000..6640a9baf --- /dev/null +++ b/verl/utils/reward_score/synlogic/object_properties_verifier.py @@ -0,0 +1,39 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectPropertiesVerifier(Verifier): + """ + 验证器用于物品拥有游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = int(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py new file mode 100644 index 000000000..25a3000b7 --- /dev/null +++ b/verl/utils/reward_score/synlogic/operation_verifier.py @@ -0,0 +1,46 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class OperationVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = math_verify.parse(data.answer) + parsed_answer = math_verify.parse(test_answer) + + if parsed_answer is None: + return False + return math_verify.verify(parsed_answer, ground_truth) + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py new file mode 100644 index 000000000..05b010561 --- /dev/null +++ b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py @@ -0,0 +1,167 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + + +class SkyscraperPuzzleVerifier(Verifier): + """ + 摩天楼游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合摩天楼游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param test_answer: 游戏类提取的网格数据 + @return: 回答是否正确的布尔值 + """ + try: + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + top = metadata['top'] + bottom = metadata['bottom'] + left = metadata['left'] + right = metadata['right'] + + self.n = n + test_answer = self.extract_answer(test_solution) + + # print(f"验证: 游戏规模 {n}×{n}") + # print(f"上方提示: {top}") + # print(f"下方提示: {bottom}") + # print(f"左侧提示: {left}") + # print(f"右侧提示: {right}") + + # 使用提取好的网格数据 + grid = test_answer + + # 检查网格是否是字符串,如果是,说明提取失败 + if isinstance(grid, str): + # print("无法提取有效网格") + return False + + # print("提取的网格:") + # for row in grid: + # print(row) + + # 检查网格规模 + if len(grid) != n or any(len(row) != n for row in grid): + # print(f"网格规模不正确,应为 {n}×{n}") + return False + + # 检查数字范围 (1 到 n) + for i in range(n): + for j in range(n): + if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n: + # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})") + return False + + # 检查每行唯一性 + for i in range(n): + if len(set(grid[i])) != n: + # print(f"第 {i+1} 行包含重复数字") + return False + + # 检查每列唯一性 + for j in range(n): + column = [grid[i][j] for i in range(n)] + if len(set(column)) != n: + # print(f"第 {j+1} 列包含重复数字") + return False + + # 检查从上方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)]) + if visible_count != top[j]: + # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}") + return False + + # 检查从下方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)]) + if visible_count != bottom[j]: + # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}") + return False + + # 检查从左侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i]) + if visible_count != left[i]: + # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}") + return False + + # 检查从右侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i][::-1]) + if visible_count != right[i]: + # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + return False + + def _count_visible_skyscrapers(self, heights): + """ + 计算从一个方向看过去能看到的摩天楼数量 + + @param heights: 从观察方向依次排列的摩天楼高度列表 + @return: 可见的摩天楼数量 + """ + visible_count = 0 + max_height = 0 + + for height in heights: + if height > max_height: + visible_count += 1 + max_height = height + + return visible_count + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取网格数据 + + @param test_solution: 模型的完整回答 + @return: 提取的解答网格数据 + """ + try: + n = self.n + + # 从 ```python 代码块中提取 + code_block_pattern = r"```python\s*\n([\s\S]*?)\n\s*```" + code_blocks = re.findall(code_block_pattern, test_solution) + + if code_blocks: + # 取第一个代码块(通常只有一个) + code_block = code_blocks[0].strip() + try: + # 直接解析代码块 + grid = ast.literal_eval(code_block) + # 验证是否为有效的n×n网格 + if (isinstance(grid, list) and + len(grid) == n and + all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + # 如果直接解析失败,尝试移除注释后再解析 + code_without_comments = re.sub(r'#.*$', '', code_block, flags=re.MULTILINE) + try: + grid = ast.literal_eval(code_without_comments.strip()) + if (isinstance(grid, list) and + len(grid) == n and + all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + pass + + # 如果提取失败,返回原始答案 + return test_solution + except Exception as e: + return test_solution \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py new file mode 100644 index 000000000..abc165d5c --- /dev/null +++ b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + +class SpaceReasoningTreeVerifier(Verifier): + """ + 验证器用于空间推理树游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + test_answer = test_answer.replace(",", ",").replace(" ", "") + ground_truth = data.answer.replace(",", ",").replace(" ", "") + test_set = set(test_answer.split(",")) + ground_truth_set = set(ground_truth.split(",")) + return test_set == ground_truth_set + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py new file mode 100644 index 000000000..249f2dc08 --- /dev/null +++ b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py @@ -0,0 +1,41 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class SpaceReasoningVerifier(Verifier): + """ + 验证器用于空间推理游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + return test_answer.lower() == data.answer.lower() + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py new file mode 100644 index 000000000..2ea42446f --- /dev/null +++ b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py @@ -0,0 +1,158 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + +import re + +class StarPlacementPuzzleVerifier(Verifier): + """ + 星星放置游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合星星放置游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param star_coords: 通过extract_answer提取的星星坐标字典 {区域: [(行,列), ...]} + @return: 回答是否正确的布尔值 + """ + try: + star_coords = self.extract_answer(test_solution) + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + k = metadata['k'] + region_grid = metadata['region_grid'] + + # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}") + + # 检查是否有有效的星星坐标 + if not star_coords: + # print("无法从回答中提取有效的星星坐标") + return False + + # 创建一个表示星星位置的网格 + star_grid = [[0 for _ in range(n)] for _ in range(n)] + for region, coords in star_coords.items(): + for coord in coords: + row, col = coord + if row < 0 or row >= n or col < 0 or col >= n: + # print(f"无效坐标: ({row},{col}) - 超出网格范围") + return False + star_grid[row][col] = 1 + + # 打印星星网格以便调试 + # print("星星网格:") + # for row in star_grid: + # print(''.join(['* ' if cell == 1 else '. ' for cell in row])) + + # 1. 检查每行是否有k颗星星 + for i in range(n): + stars_in_row = sum(star_grid[i]) + if stars_in_row != k: + # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗") + return False + + # 2. 检查每列是否有k颗星星 + for j in range(n): + stars_in_col = sum(star_grid[i][j] for i in range(n)) + if stars_in_col != k: + # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗") + return False + + # 3. 检查每个区域是否有k颗星星 + regions = {} + for i in range(n): + for j in range(n): + region = region_grid[i][j] + if region not in regions: + regions[region] = [] + regions[region].append((i, j)) + + for region, cells in regions.items(): + stars_in_region = sum(star_grid[i][j] for i, j in cells) + if stars_in_region != k: + # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗") + return False + + # 4. 检查星星是否互不相邻(水平、垂直、对角线) + for i in range(n): + for j in range(n): + if star_grid[i][j] == 1: + # 检查周围8个方向 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue # 跳过自身 + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1: + # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取星星坐标 + + @param test_solution: 模型的完整回答 + @return: 提取的星星坐标字典 {区域: [(行,列), ...]} + """ + try: + # 从Python代码块中提取 + python_match = re.search(r'```python\s*\n(.*?)\n\s*```', test_solution, re.DOTALL) + if not python_match: + # print("回答中没有找到```python代码块") + return None + + code_content = python_match.group(1) + + # 尝试从Python代码中提取字典 + try: + # 先尝试直接提取字典内容 + dict_match = re.search(r'\{[^{}]*\}', code_content, re.DOTALL) + if dict_match: + dict_str = dict_match.group(0) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + pass + + # 如果上面的方法失败,尝试解析变量赋值 + assign_match = re.search(r'(\w+)\s*=\s*(\{[^{}]*\})', code_content, re.DOTALL) + if assign_match: + dict_str = assign_match.group(2) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + pass + except Exception as e: + pass + + return None + + except Exception as e: + return None \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/synlogic.py b/verl/utils/reward_score/synlogic/synlogic.py new file mode 100644 index 000000000..29b08b122 --- /dev/null +++ b/verl/utils/reward_score/synlogic/synlogic.py @@ -0,0 +1,92 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +# from .game_of_24.scripts.game_of_24_verifier import GameOf24Verifier +# from .cryptarithm.scripts.cryptarithm_verifier import CryptarithmVerifier +# from .survo.scripts.survo_verifier import SurvoVerifier +from .campsite_verifier import CampsiteVerifier +from .skyscraper_puzzle_verifier import SkyscraperPuzzleVerifier +from .web_of_lies_verifier import WebOfLiesVerifier +from .goods_exchange_verifier import GoodsExchangeVerifier +# from .sudoku.scripts.sudoku_verifier import SudokuVerifier +# from corpus.misc.tasks.zebra_puzzle.scripts.zebra_puzzle_verifier import ZebraPuzzleVerifier +# from corpus.misc.tasks.bbeh.scripts.bbeh_verifier import BBEHVerifier +# from corpus.misc.tasks.arc_agi.scripts.arc_agi_verifier import ArcAGIVerifier +from .object_properties_verifier import ObjectPropertiesVerifier +from .object_counting_verifier import ObjectCountingVerifier +from .star_placement_puzzle_verifier import StarPlacementPuzzleVerifier +from .arrow_maze_verifier import ArrowMazeVerifier +# from .kukurasu.scripts.kukurasu_verifier import KukurasuVerifier +from .number_wall_verifier import NumberWallVerifier +from .numbrix_verifier import NumbrixVerifier +from .norinori_verifier import NorinoriVerifier +from .minesweeper_verifier import MinesweeperVerifier +from .operation_verifier import OperationVerifier +from .word_sorting_mistake_verifier import WordSortingMistakeVerifier +from .math_path_verifier import MathPathVerifier +from .boolean_expressions_verifier import BooleanExpressionsVerifier +from .space_reasoning_verifier import SpaceReasoningVerifier +from .space_reasoning_tree_verifier import SpaceReasoningTreeVerifier +from .word_sorting_verifier import WordSortingVerifier +# from corpus.misc.tasks.gpqa.scripts.gpqa_verifier import GPQAVerifier +# from .cipher.scripts.cipher_verifier import CipherVerifier +from .time_sequence_verifier import TimeSequenceVerifier +from .wordscapes_verifier import WordscapesVerifier +# from corpus.misc.tasks.bbh.scripts.boolean_expressions_verifier import BBHBooleanExpressionsVerifier +# from corpus.misc.tasks.bbh.scripts.causal_judgement_verifier import BBHCausalJudgementVerifier # yes no +# from corpus.misc.tasks.bbh.scripts.date_understanding_verifier import BBHDateUnderstandingVerifier # multi-choice +# from corpus.misc.tasks.bbh.scripts.dyck_languages_verifier import BBHDyckLanguagesVerifier +# from corpus.misc.tasks.bbh.scripts.formal_fallacies_verifier import BBHFormalFallaciesVerifier +# from corpus.misc.tasks.bbh.scripts.multistep_arithmetic_two_verifier import BBHMultistepArithmeticVerifier # number +# from corpus.misc.tasks.bbh.scripts.sports_understanding_verifier import BBHSportsUnderstandingVerifier +# from corpus.misc.tasks.bbh.scripts.web_of_lies_verifier import BBHWebOfLiesVerifier +# from corpus.misc.tasks.bbh.scripts.word_sorting_verifier import BBHWordSortingVerifier +from .game_of_buggy_tables_verifier import BuggyTableVerifier +# from .calcudoko.scripts.calcudoko_verifier import CalcudokoVerifier +from .dyck_language_verifier import DyckLanguageVerifier +from .dyck_language_errors_verifier import DyckLanguageErrorsVerifier +from .dyck_language_reasoning_errors_verifier import DyckLanguageReasoningErrorsVerifier +# from .futoshiki.scripts.futoshiki_verifier import FutoshikiVerifier + +# NOTE: Add new tasks in alphabetical order +verifier_classes = { + "arrow_maze": ArrowMazeVerifier, + "boolean_expressions": BooleanExpressionsVerifier, + "buggy_tables": BuggyTableVerifier, + # "calcudoko": CalcudokoVerifier, + "campsite": CampsiteVerifier, + # "cipher": CipherVerifier, + # "cryptarithm": CryptarithmVerifier, + "dyck_language": DyckLanguageVerifier, + "dyck_language_errors": DyckLanguageErrorsVerifier, + "dyck_language_reasoning_errors": DyckLanguageReasoningErrorsVerifier, + # "futoshiki": FutoshikiVerifier, + "goods_exchange": GoodsExchangeVerifier, + # "gpqa_diamond": GPQAVerifier, + # "kukurasu": KukurasuVerifier, + "math_path": MathPathVerifier, + # "arc_agi": ArcAGIVerifier, + # "arc_agi_2": ArcAGIVerifier, + # "mathador": GameOf24Verifier, + "minesweeper": MinesweeperVerifier, + "norinori": NorinoriVerifier, + "number_wall": NumberWallVerifier, + "numbrix": NumbrixVerifier, + "object_counting": ObjectCountingVerifier, + "object_properties": ObjectPropertiesVerifier, + "operation": OperationVerifier, + "skyscraper_puzzle": SkyscraperPuzzleVerifier, + "space_reasoning": SpaceReasoningVerifier, + "space_reasoning_tree": SpaceReasoningTreeVerifier, + "star_placement_puzzle": StarPlacementPuzzleVerifier, + # "sudoku": SudokuVerifier, + # "survo": SurvoVerifier, + "time_sequence": TimeSequenceVerifier, + "web_of_lies": WebOfLiesVerifier, + "word_sorting": WordSortingVerifier, + "word_sorting_mistake": WordSortingMistakeVerifier, + "wordscapes": WordscapesVerifier, + # "zebra_puzzle": ZebraPuzzleVerifier, + # ** bbeh_classes, + # ** bbh_classes, +} \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/time_sequence_verifier.py b/verl/utils/reward_score/synlogic/time_sequence_verifier.py new file mode 100644 index 000000000..8ae6aa161 --- /dev/null +++ b/verl/utils/reward_score/synlogic/time_sequence_verifier.py @@ -0,0 +1,66 @@ +import json +import numpy as np +from .data import Data +from .verifier import Verifier +import re + +class TimeSequenceVerifier(Verifier): + """ + 验证器用于验证 time sequence 的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的答案,格式为数字列表 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 解析元数据 + metadata = data.metadata + true_answers = metadata['records']['answers'] + + # 解析模型给出的列表 + try: + test_list = json.loads(test_answer.replace(",", ",")) + except: + return False + + try: + if test_list[0]!=true_answers['answer_maxLen']: + # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + if test_list[1]!=true_answers['answer_nums']: + # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + except: + return False + + # 所有检查都通过 + # print("验证结果: 正确") + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案(矩阵) + + @param test_solution: 模型的完整回答 + @return: 提取答案列表 + """ + if not test_solution: + return "" + + # 尝试提取列表 + matrix_pattern = r'\[.*?\]' + matrix_matches = re.findall(matrix_pattern, test_solution, re.DOTALL) + if matrix_matches: + # 使用最后一个匹配的列表 + # print(matrix_matches) + return matrix_matches[-1].strip() + + # 如果失败,返回空字符串 + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/verifier.py b/verl/utils/reward_score/synlogic/verifier.py new file mode 100644 index 000000000..498e87a82 --- /dev/null +++ b/verl/utils/reward_score/synlogic/verifier.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from .data import Data + +class Verifier(ABC): + """ + Base class for verifier + """ + def __init__(self): + pass + + @abstractmethod + def verify(self, data: Data, test_answer: str): + """ + Verify whether the test answer is consistent with the gold answer + @param data: Data + @param test_answer: str + @return: bool + """ + raise NotImplementedError("Verifier.verify() is not implemented") + + @abstractmethod + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + @param test_solution: str + @return: str + """ + raise NotImplementedError("Verifier.extract_answer() is not implemented") + +import re + +THOUGHT_DELIMITER_START = "" +THOUGHT_DELIMITER_END = "" + +def _extract_answer(text): + # 定义正则表达式模式,匹配 之间的内容 + pattern = r'(.*?)' + + # 使用 re.search 查找第一个匹配项 + match = re.search(pattern, text, re.DOTALL) + + # 如果找到匹配项,返回匹配的内容 + if match: + return match.group(1).strip() + else: + return None + +def _extract_solution_with_thought(solution_str): + + model_output = solution_str + + if THOUGHT_DELIMITER_END in solution_str: + model_output = solution_str.split(THOUGHT_DELIMITER_END)[1] + + predict_answer = _extract_answer(model_output) + + + if predict_answer is not None: + return predict_answer + else: + return "" + + +class ExactMatchVerifier(Verifier): + """ + Verifier for Exact Match + """ + def verify(self, data: Data, test_solution: str): + try: + test_answer = self.extract_answer(test_solution) + ground_truth = data.answer + correct = test_answer == ground_truth + if correct: + acc_score = 1.0 + else: + acc_score = 0 + + return acc_score + except: + return False + + def extract_answer(self, test_solution: str): + return _extract_solution_with_thought(solution_str=test_solution) \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py new file mode 100644 index 000000000..e301d2cea --- /dev/null +++ b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py @@ -0,0 +1,134 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +class WebOfLiesVerifier(Verifier): + """ + 验证器用于检查谎言之网游戏的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 获取预期答案和测试答案 + expected_answer = data.answer.lower() + + # 清理测试答案 + test_answer = test_answer.lower() + + # 提取预期答案中的真假值 + expected_truths = self._parse_answer(expected_answer) + + # 提取测试答案中的真假值 + test_truths = self._parse_answer(test_answer) + + # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}") + + # 检查答案列表长度是否匹配 + if len(expected_truths) != len(test_truths): + # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}") + return False + + # 检查每个位置的答案是否匹配 + for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)): + if expected != actual: + # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}") + return False + + # print("验证成功: 所有答案匹配") + return True + + except Exception as e: + return False + + def _parse_answer(self, answer_str): + """ + 从答案字符串中解析出真假值列表 + + @param answer_str: 答案字符串 + @return: 真假值列表,True表示说真话,False表示说谎话 + """ + # 尝试匹配英文答案格式 (yes/no) + yes_pattern = r'yes|true|truth' + no_pattern = r'no|false|lie' + + # 尝试匹配中文答案格式 (是/否) + cn_yes_pattern = r'是|真话|真' + cn_no_pattern = r'否|假话|假|谎' + + # 组合模式 + yes_patterns = f'({yes_pattern}|{cn_yes_pattern})' + no_patterns = f'({no_pattern}|{cn_no_pattern})' + + # 根据答案字符串中的关键词确定真假值 + truths = [] + + # 寻找所有可能的yes/no或是/否答案 + all_answers = re.findall(rf'{yes_patterns}|{no_patterns}', answer_str) + + for match in all_answers: + # match是一个元组,需要找到非空的元素 + match_str = next((m for m in match if m), '') + + if re.search(yes_pattern, match_str) or re.search(cn_yes_pattern, match_str): + truths.append(True) + elif re.search(no_pattern, match_str) or re.search(cn_no_pattern, match_str): + truths.append(False) + + return truths + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + if not test_solution: + return "" + # 中文模式 + cn_patterns = [ + r'答案是[::]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"答案是:**是,否,是**"格式 + ] + + # 英文模式 + en_patterns = [ + r'[Tt]he answer is[::=]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"The answer is: **yes, no, yes**"格式 + ] + + # 尝试匹配所有模式 + patterns = cn_patterns + en_patterns + + for pattern in patterns: + matches = re.findall(pattern, test_solution, re.DOTALL) + if matches: + return matches[-1].strip() + + # 如果上面的模式都没匹配到,尝试更宽松的匹配 + # 查找最后一行中的加粗文本 + lines = test_solution.strip().split('\n') + if lines: + last_line = lines[-1].strip() + bold_match = re.search(r'\*\*([^*]+)\*\*', last_line) + if bold_match: + return bold_match.group(1).strip() + + # 尝试匹配"答案是"或"The answer is"后面的文本 + answer_match = re.search(r'(?:答案是|[Tt]he answer is)[::=]?\s*(.*?)(?:[.。]|$)', last_line) + if answer_match: + return answer_match.group(1).strip() + + # 如果没有找到格式化的答案,尝试直接匹配yes/no或是/否序列 + yes_no_pattern = r'(?:\b(?:yes|no|是|否)\b[,,\s]*)+' + matches = re.findall(yes_no_pattern, test_solution.lower()) + if matches: + return matches[-1].strip() + + # 如果没有匹配到任何模式,返回空字符串 + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py new file mode 100644 index 000000000..f2ee2109a --- /dev/null +++ b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingMistakeVerifier(Verifier): + """ + 验证器用于word sorting mistake的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = data.answer if data.answer is not None else "No" + parsed_answer = self.extract_answer(test_answer) + + if parsed_answer is None: + return False + + if parsed_answer.isdigit(): + try: + return int(parsed_answer) == int(ground_truth) + except Exception as e: + return False + else: + return parsed_answer.lower() == ground_truth.lower() + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\boxed{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/word_sorting_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_verifier.py new file mode 100644 index 000000000..4032ac21c --- /dev/null +++ b/verl/utils/reward_score/synlogic/word_sorting_verifier.py @@ -0,0 +1,42 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingVerifier(Verifier): + """ + 验证器用于单词排序游戏的答案是否正确 + """ + def str2list(self, answer_str): + # 替换中文逗号为英文逗号,并删除所有空格 + answer_str = answer_str.replace(",", ",").replace(" ", "") + return [w.strip() for w in answer_str.split(",")] + + def verify(self, data: Data, test_answer: str): + try: + ground_truth = self.str2list(data.answer) + parsed_answer = self.str2list(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return parsed_answer == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/wordscapes_verifier.py b/verl/utils/reward_score/synlogic/wordscapes_verifier.py new file mode 100644 index 000000000..1efeb5e21 --- /dev/null +++ b/verl/utils/reward_score/synlogic/wordscapes_verifier.py @@ -0,0 +1,157 @@ +""" +Wordscapes verifier module for the reasonreason framework. +""" + +import json +import re +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +debug_mode = False + +class WordscapesVerifier(Verifier): + """ + Verifier for Wordscapes game + """ + def verify(self, data, test_solution: str): + """ + Verify whether the test answer is consistent with the gold answer + + Args: + data: WordscapesData + test_solution: str containing the solution + + Returns: + float: Score between 0 and 1 + """ + try: + extracted_answer = self.extract_answer(test_solution) + if not extracted_answer: + return False + + if debug_mode: + for row in extracted_answer: + print(" ".join(cell if cell != " " else "_" for cell in row)) + + # Get grid, across_words, and down_words from data + grid = data.metadata["grid"] + across_words = data.metadata["across_words"] + down_words = data.metadata["down_words"] + + # Validate grid dimensions + if len(extracted_answer) != len(grid): + # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}") + return False + + for i in range(len(grid)): + if len(extracted_answer[i]) != len(grid[i]): + # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}") + return False + + # Check if the answer respects the grid layout (X for letters, 0 for empty) + for i in range(len(grid)): + for j in range(len(grid[i])): + if grid[i][j] == "0" and extracted_answer[i][j].strip(): + # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'") + return False + if grid[i][j] == "X" and not extracted_answer[i][j].strip(): + # print(f"Expected letter at position ({i},{j}), got empty space") + return False + + # Verify across words + for word in across_words: + found = False + for i in range(len(extracted_answer)): + row_str = ''.join(extracted_answer[i]).replace(' ', '').lower() + if word.lower() in row_str: + found = True + break + if not found and word: + # print(f"Across word '{word}' not found in the grid") + return 0 + + # Verify down words + for word in down_words: + found = False + for j in range(len(extracted_answer[0])): + col = [] + for i in range(len(extracted_answer)): + if j < len(extracted_answer[i]): + col.append(extracted_answer[i][j]) + col_str = ''.join(col).replace(' ', '').lower() + if word.lower() in col_str: + found = True + break + if not found and word: # Only check if word is not empty + # print(f"Down word '{word}' not found in the grid") + return False + + # All checks passed + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + + Args: + test_solution: str + + Returns: + list: 2D grid of the answer or None if extraction fails + """ + try: + # Remove thoughts if present + if THOUGHT_DELIMITER_START in test_solution and THOUGHT_DELIMITER_END in test_solution: + # Extract only the part after the thoughts + thought_end_pos = test_solution.rfind(THOUGHT_DELIMITER_END) + if thought_end_pos >= 0: + test_solution = test_solution[thought_end_pos + len(THOUGHT_DELIMITER_END):] + + # Clean up the response and find the grid pattern + # Look for a pattern like [[...]] or [[[...]]] + grid_pattern = re.search(r'\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]', test_solution, re.DOTALL) + if not grid_pattern: + return None + + grid_text = grid_pattern.group(1) + + # Handle various formats + rows = [] + + # Check if rows are separated by commas + split_rows = re.split(r'\],\s*\[', grid_text) + + for row_text in split_rows: + # Clean the row text and extract characters + row_text = row_text.strip().strip('[],') + + # Extract quoted characters: "X" or 'X' or just X + chars = [] + + # Look for quoted strings or standalone characters + char_matches = re.findall(r'\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)', row_text) + + for match in char_matches: + # Take the first non-empty group from each match + char = next((x for x in match if x), "") + + # Handle numeric or empty values (0, "", '') + if char == "0" or char == "": + char = " " + + chars.append(char) + + if chars: # Only add non-empty rows + rows.append(chars) + + # Make sure we have a valid grid + if not rows or not all(rows): + return None + + return rows + + except Exception as e: + print(f"NOTE!!! parse error!!!! (Wordscapes): {e}") + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/tablereason.py b/verl/utils/reward_score/tablereason.py index e71c1a4ee..f9fcabb86 100644 --- a/verl/utils/reward_score/tablereason.py +++ b/verl/utils/reward_score/tablereason.py @@ -13,9 +13,11 @@ def _check_single_answer(answer: str, ground_truth: str) -> bool: return math.is_equiv(answer, ground_truth) def drop_latex_text(answer: str) -> str: - # Remove \\text{} from "20 \\text{to} 39". There could be multiple \\text{} in the answer. + # Remove \text{} from "20 \text{to} 39". There could be multiple \text{} in the answer. # Replace \text{something} with something - answer = re.sub(r'\\\\text\{([^}]*)\}', r'\1', answer) + # Handle both single and double backslash cases + answer = re.sub(r'\\\\text\{([^}]*)\}', r'\1', answer) # Double backslash + answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer) # Single backslash answer = re.sub(r'\\\\', r'', answer) return answer @@ -32,7 +34,7 @@ def compute_score(model_output: str, ground_truth: str, extra_info: any = None) else: answer = solution_str - # print(f">>> {answer}, {ground_truth}") + # print(f">>> answer: '{answer}', ground_truth: '{ground_truth}'") if "|" not in ground_truth: # Single numeric answer score = _check_single_answer(answer, ground_truth) diff --git a/verl/utils/rollout_trace.py b/verl/utils/rollout_trace.py new file mode 100644 index 000000000..e34e285d0 --- /dev/null +++ b/verl/utils/rollout_trace.py @@ -0,0 +1,224 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import contextlib +import functools +import inspect +import os +from typing import Optional + + +class RolloutTraceConfig: + _instance: Optional["RolloutTraceConfig"] = None + backend: Optional[str] = None + client: Optional[object] = None + token2text: bool = False + _initialized: bool = False + project_name: str = None + experiment_name: str = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + @classmethod + def get_instance(cls) -> "RolloutTraceConfig": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False): + config = cls.get_instance() + if config._initialized: + return + + config.backend = backend + config.token2text = token2text + config.project_name = project_name + config.experiment_name = experiment_name + + if backend == "weave": + import weave + + config.client = weave.init(project_name) + elif backend == "mlflow": + import mlflow + + mlflow.config.enable_async_logging() + config.client = mlflow + + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + + mlflow.set_experiment(project_name) + else: + config.client = None + + config._initialized = True + + @classmethod + def get_backend(cls) -> Optional[str]: + return cls.get_instance().backend + + @classmethod + def get_client(cls) -> Optional[object]: + return cls.get_instance().client + + @classmethod + def enable_token2text(cls) -> Optional[bool]: + return cls.get_instance().token2text + + @classmethod + def reset(cls): + cls._instance = None + + +@contextlib.contextmanager +def rollout_trace_attr(sample_index=None, step=None, rollout_n=None, name="rollout_trace", validate=False): + """A context manager to add attributes to a trace for the configured backend.""" + backend = RolloutTraceConfig.get_backend() + attributes = {} + if backend: + if sample_index is not None: + attributes["sample_index"] = sample_index + if step is not None: + attributes["step"] = step + if rollout_n is not None: + attributes["rollout_n"] = rollout_n + attributes["validate"] = validate + attributes["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name + + if not attributes or backend is None: + yield + return + + if backend == "weave": + import weave + + with weave.attributes(attributes): + yield + elif backend == "mlflow": + import mlflow + + with mlflow.start_span(name=name) as span: + trace_id = span.trace_id + for key, value in attributes.items(): + mlflow.set_trace_tag(trace_id, str(key), str(value)) + yield + else: + yield + + +def rollout_trace_op(func): + @functools.wraps(func) + async def async_wrapper(self, *args, **kwargs): + backend = RolloutTraceConfig.get_backend() + enable_token2text = RolloutTraceConfig.enable_token2text() + if backend is None: + return await func(self, *args, **kwargs) + + sig = inspect.signature(func) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + inputs = dict(bound_args.arguments) + del inputs["self"] + + async def add_token2text(self, result): + if hasattr(result, "prompt_ids") and hasattr(self, "tokenizer") and hasattr(self.tokenizer, "decode"): + _result = vars(result) + loop = asyncio.get_running_loop() + if hasattr(result, "prompt_ids"): + prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids) + _result["prompt_text"] = prompt_text + + if hasattr(result, "response_ids"): + response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids) + _result["response_text"] = response_text + return _result + return result + + if backend == "weave": + tracer = RolloutTraceConfig.get_client() + from weave.trace.context import call_context + + cur_attributes = {**call_context.call_attributes.get()} + call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + try: + result = await func(self, *args, **kwargs) + + if enable_token2text: + _result = await add_token2text(self, result) + tracer.finish_call(call, output=_result) + else: + tracer.finish_call(call, output=result) + + return result + + except Exception as e: + tracer.finish_call(call, exception=e) + raise e + elif backend == "mlflow": + import mlflow + + with mlflow.start_span(name=func.__qualname__) as span: + span.set_inputs(inputs) + result = await func(self, *args, **kwargs) + if enable_token2text: + _result = await add_token2text(self, result) + span.set_outputs(_result) + else: + span.set_outputs(result) + + return result + + else: + return await func(self, *args, **kwargs) + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + backend = RolloutTraceConfig.get_backend() + if backend is None: + return func(self, *args, **kwargs) + + sig = inspect.signature(func) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + inputs = dict(bound_args.arguments) + del inputs["self"] + + if backend == "weave": + tracer = RolloutTraceConfig.get_client() + from weave.trace.context import call_context + + cur_attributes = {**call_context.call_attributes.get()} + call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + try: + result = func(self, *args, **kwargs) + tracer.finish_call(call, output=result) + return result + except Exception as e: + tracer.finish_call(call, exception=e) + raise e + elif backend == "mlflow": + import mlflow + + return mlflow.trace(func)(self, *args, **kwargs) + else: + return func(self, *args, **kwargs) + + return async_wrapper if inspect.iscoroutinefunction(func) else wrapper diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index e2e567050..bde116adf 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -14,13 +14,16 @@ import copy import heapq -from typing import List, Tuple +from itertools import chain import torch from torch import distributed as dist +from verl.protocol import DataProto +from verl.utils.device import get_device_name -def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): + +def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool): # see: https://en.wikipedia.org/wiki/Largest_differencing_method class Set: def __init__(self) -> None: @@ -44,7 +47,7 @@ def __lt__(self, other): return self.items < other.items class State: - def __init__(self, items: List[Tuple[int, int]], k: int) -> None: + def __init__(self, items: list[tuple[int, int]], k: int) -> None: self.k = k # sets should always be decreasing order self.sets = [Set() for _ in range(k)] @@ -118,11 +121,13 @@ def __repr__(self) -> str: partitions = final_state.get_partitions() if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) return partitions -def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool): +def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool): bias = sum(seqlen_list) + 1 if equal_size else 0 sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)] partitions = [[] for _ in range(k_partitions)] @@ -136,11 +141,13 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool partition_sums[min_idx] += seqlen if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) return partitions -def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool): +def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): """ Calculates partitions of indices from seqlen_list such that the sum of sequence lengths in each partition is balanced. Uses the Karmarkar-Karp differencing method. @@ -184,14 +191,28 @@ def _check_and_sort_partitions(partitions): return _check_and_sort_partitions(partitions) -def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix): - # add some metrics of seqlen sum on dp ranks +def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix): + """ + Calculate and log metrics related to sequence length imbalance before and after partitioning. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + partitions (List[List[int]]): A list of partitions, where each inner list contains indices + from seqlen_list assigned to that partition. + prefix (str): A prefix to be added to each metric key in the returned dictionary. + + Returns: + dict: A dictionary containing metrics related to sequence length imbalance. + """ + # Get the number of partitions k_partition = len(partitions) # assert len(seqlen_list) % k_partition == 0 batch_size = len(seqlen_list) // k_partition min_sum_seqlen = None max_sum_seqlen = None total_sum_seqlen = 0 + + # Iterate over each batch of sequence lengths for offset in range(0, len(seqlen_list), batch_size): cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: @@ -226,7 +247,15 @@ def roundup_divisible(a, b): return ((a + b - 1) // b) * b -def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_divided_by=None, same_micro_num_in_dp=True, min_num_micro_batch=None): +def rearrange_micro_batches( + batch, + max_token_len, + dp_group=None, + num_batches_divided_by=None, + same_micro_num_in_dp=True, + min_num_micro_batch=None, + use_dynamic_bsz_balance=True, +): """ Split a batch into micro-batches by total token count, with optional DP sync and padding. @@ -237,6 +266,7 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_div num_batches_divided_by (optional): virtual pipeline parallel size, for megatron. same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count. min_num_micro_batch (int, optional): force at least this many splits (pads empty ones). + use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches Returns: List[TensorDict]: the micro-batches. @@ -244,7 +274,9 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_div """ # this is per local micro_bsz max_seq_len = batch["attention_mask"].shape[-1] - assert max_token_len >= max_seq_len, f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" + assert max_token_len >= max_seq_len, ( + f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" + ) seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() # NOTE: num_microbatches <= batch_size, so take the min of this two. @@ -253,7 +285,7 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_div # used to support pp num_micro_batches = max(min_num_micro_batch, num_micro_batches) if dist.is_initialized() and same_micro_num_in_dp: - num_micro_batches = torch.tensor([num_micro_batches], device="cuda") + num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item() if num_batches_divided_by is not None: @@ -264,6 +296,16 @@ def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_div micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) + if use_dynamic_bsz_balance: + # Use the sum of squared sequence lengths to approximate attention computation workload + micro_bsz_idx.sort( + key=lambda partition: ( + sum(seq_len_effective[idx] ** 2 for idx in partition), + min(partition) if partition else 0, + ), + reverse=True, + ) + micro_batches = [] for partition in micro_bsz_idx: @@ -293,3 +335,41 @@ def get_reverse_idx(idx_map): reverse_idx_map[idx] = i return reverse_idx_map + + +def prepare_dynamic_batch(data: DataProto, max_token_len: int) -> tuple[list[DataProto], list[list[int]]]: + """ + Prepare a batch for dynamic batching. + + Args: + data (DataProto): The input data. + max_token_len (int): The maximum token length for dynamic batching. + + Returns: + Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects + and a list of index lists. + """ + batch, batch_idx_list = rearrange_micro_batches(data.batch, max_token_len=max_token_len) + micro_batches = [] + for i, batch_idx in enumerate(batch_idx_list): + tensors = dict(batch[i]) + non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()} + micro_batches.append(DataProto.from_dict(tensors, non_tensors)) + + return micro_batches, batch_idx_list + + +def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor: + """ + Restore a batch from dynamic batching. + + Args: + data (torch.Tensor): The input data. + batch_idx_list (List[List[int]]): The list of index lists. + + Returns: + torch.Tensor: The restored data. + """ + indices = list(chain.from_iterable(batch_idx_list)) + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + return data[revert_indices] diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index 019683f02..668ea3e14 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -52,7 +52,9 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw if correct_gemma2 and isinstance(name_or_path, str) and "gemma-2-2b-it" in name_or_path: # the EOS token in gemma2 is ambiguious, which may worsen RL performance. # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a - warnings.warn("Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", stacklevel=1) + warnings.warn( + "Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", stacklevel=1 + ) kwargs["eos_token"] = "" kwargs["eos_token_id"] = 107 tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) @@ -74,8 +76,11 @@ def hf_processor(name_or_path, **kwargs): try: processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) - except Exception: + except Exception as e: processor = None + # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid + # silent failure + warnings.warn(f"Failed to create processor: {e}. This may affect multimodal processing", stacklevel=1) # Avoid load tokenizer, see: # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 if processor is not None and "Processor" not in processor.__class__.__name__: diff --git a/verl/utils/torch_dtypes.py b/verl/utils/torch_dtypes.py index 015dae5a1..f2f445c26 100644 --- a/verl/utils/torch_dtypes.py +++ b/verl/utils/torch_dtypes.py @@ -15,8 +15,6 @@ Adapted from Cruise. """ -from typing import Union - import torch HALF_LIST = [16, "16", "fp16", "float16", torch.float16] @@ -40,7 +38,7 @@ class PrecisionType: MIXED = "mixed" @staticmethod - def supported_type(precision: Union[str, int]) -> bool: + def supported_type(precision: str | int) -> bool: return any(x == precision for x in PrecisionType) @staticmethod diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index e728758d4..df91ad778 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -17,7 +17,7 @@ import math from contextlib import contextmanager -from typing import Dict, List, Optional, Union +from typing import Optional import torch import torch.distributed @@ -28,6 +28,8 @@ from torch.optim.lr_scheduler import LambdaLR from transformers import PreTrainedTokenizer +from verl.utils.device import get_device_name, get_torch_device + try: from flash_attn.ops.triton.cross_entropy import cross_entropy_loss @@ -36,6 +38,14 @@ FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False +try: + import torch_npu + + NPU_CROSS_ENTROPY_LOSS_AVAILABLE = hasattr(torch_npu, "npu_cross_entropy_loss") +except ImportError: + NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False + + def gather_from_labels(data, label): """Gather the label from data. The value in label should be [0, vocab_size) @@ -75,6 +85,8 @@ def logprobs_from_logits(logits, labels, inplace_backward=True): labels = labels.reshape(-1) output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward) output = output.view(*batch_dim) + elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE: + output = logprobs_from_logits_torch_npu(logits, labels) else: output = logprobs_from_logits_v2(logits, labels) return output @@ -82,10 +94,19 @@ def logprobs_from_logits(logits, labels, inplace_backward=True): def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) - assert isinstance(output, tuple), "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + assert isinstance(output, tuple), ( + "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + ) return -output[0] +def logprobs_from_logits_torch_npu(logits, labels): + batch_dim = logits.shape[:-1] + logits = logits.reshape(-1, logits.shape[-1]) + loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none") + return -loss.view(*batch_dim) + + def logprobs_from_logits_naive(logits, labels): logp = F.log_softmax(logits, dim=-1) logpy = gather_from_labels(logp, labels) @@ -104,7 +125,7 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels): else: # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach logprobs_labels = [] - for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption + for row_logits, row_labels in zip(logits, labels, strict=True): # loop to reduce peak mem consumption row_logprobs = F.log_softmax(row_logits, dim=-1) row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) logprobs_labels.append(row_logprobs_labels) @@ -128,9 +149,23 @@ def entropy_from_logits(logits: torch.Tensor): return entropy +def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048): + """Memory-efficient entropy calculation with chunking.""" + entropy = torch.zeros(logits.shape[0], device=logits.device) + for i in range(0, logits.shape[0], chunk_size): + logits_chunk = logits[i : i + chunk_size].float() + pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1) + entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1) + entropy[i : i + chunk_size] = entropy_chunk + return entropy + + def masked_sum(values, mask, axis=None): """Compute mean of tensor with a masked values.""" - return (values * mask).sum(axis=axis) + # If NaNs exist out of mask, replace NaNs in values with a value that + # won't affect the sum (e.g., 0 for masked regions) + valid_values = torch.where(mask.bool(), values, 0.0) + return (valid_values * mask).sum(axis=axis) def masked_mean(values, mask, axis=None): @@ -146,7 +181,8 @@ def masked_mean(values, mask, axis=None): Returns: Tensor: Masked mean, with shape equal to `values` reduced over `axis`. """ - return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8) + s = masked_sum(values, mask, axis) + return s / (mask.sum(axis=axis) + 1e-8) def masked_var(values, mask, unbiased=True): @@ -187,7 +223,7 @@ def masked_whiten(values, mask, shift_mean=True): return whitened -def get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64): +def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64): """ end of sentence token can be int or list: 1 or [1, 2] e.g. @@ -218,7 +254,7 @@ def compute_grad_norm(model: nn.Module): return total_grad_square -def broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], TensorDict], src, group): +def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group): """ TODO: optimize this. Technically, we only need one broadcast """ @@ -227,7 +263,7 @@ def broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], TensorDict], s torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False) -def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], size, group, dim=0): +def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0): """ TODO: optimize this. - We can use async ops @@ -261,8 +297,10 @@ def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], return output -def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]: - assert tensors.batch_size[0] % batch_size == 0, f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" +def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]: + assert tensors.batch_size[0] % batch_size == 0, ( + f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" + ) return tensors.split(batch_size) @@ -306,7 +344,7 @@ def postprocess_data( max_length: Target sequence length pad_token_id: Padding token ID left_pad: Pad left if True - truncation: "left", "right" or "error" + truncation: "left", "right", "middle" or "error" Returns: (input_ids, attention_mask) padded/truncated to max_length @@ -316,8 +354,12 @@ def postprocess_data( sequence_length = input_ids.shape[-1] if sequence_length < max_length: - input_ids = pad_sequence_to_length(input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad) - attention_mask = pad_sequence_to_length(attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad) + input_ids = pad_sequence_to_length( + input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad + ) + attention_mask = pad_sequence_to_length( + attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad + ) elif sequence_length > max_length: if truncation == "left": # actually, left truncation may not be reasonable @@ -339,7 +381,9 @@ def postprocess_data( return input_ids, attention_mask -def tokenize_and_postprocess_data(prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation="error"): +def tokenize_and_postprocess_data( + prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation="error" +): """Tokenize text and process outputs to consistent tensor shapes. Args: @@ -370,7 +414,7 @@ def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): no_padding_batch(List[List[int]]): contains the rmpad token ids per query. """ no_padding_batch = [] - for ids, mask in zip(input_ids, attention_mask): + for ids, mask in zip(input_ids, attention_mask, strict=True): no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist()) return no_padding_batch @@ -411,7 +455,9 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output @@ -437,7 +483,9 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batc input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] return output @@ -540,8 +588,12 @@ def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -642,14 +694,14 @@ def lr_lambda(current_step): @contextmanager -def check_cuda_is_available(): +def check_device_is_available(): """ Some modules must be imported after CUDA is initialized. Such as sglang's sharding manager. This context manager checks if CUDA is available and raises an error if it is not. """ - if not torch.cuda.is_available(): - raise RuntimeError("CUDA must be initialized before importing this module.") + if not get_torch_device().is_available(): + raise RuntimeError("Device {} must be initialized before importing this module.".format(get_device_name())) yield @@ -668,7 +720,7 @@ def distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=Tru """ # Sum the local tensor across all processes local_sum = torch.sum(local_tensor) - local_num = torch.tensor(torch.numel(local_tensor), device="cuda") + local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name()) torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index a2d658bea..ab9906b2f 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -16,11 +16,12 @@ """ import dataclasses +import os from enum import Enum from functools import partial from pathlib import Path -from typing import Any, Dict, List, Union -import os +from typing import Any + class Tracking: """A unified tracking interface for logging experiment data to multiple backends. @@ -35,7 +36,7 @@ class Tracking: supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console", "clearml"] - def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None): + def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None, run_id=None): if isinstance(default_backend, str): default_backend = [default_backend] for backend in default_backend: @@ -52,9 +53,12 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li import wandb settings = None - if config["trainer"].get("wandb_proxy", None): + if config and config["trainer"].get("wandb_proxy", None): settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"]) - wandb.init(project=project_name, name=experiment_name, config=config, settings=settings) + if run_id is None or run_id == "": + wandb.init(project=project_name, name=experiment_name, config=config, settings=settings) + else: + wandb.init(project=project_name, name=experiment_name, config=config, settings=settings, resume="must", id=run_id) self.logger["wandb"] = wandb if "mlflow" in default_backend: @@ -62,9 +66,8 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li import mlflow - MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", None) - if MLFLOW_TRACKING_URI: - mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) # Project_name is actually experiment_name in MLFlow # If experiment does not exist, will create a new experiment @@ -116,10 +119,10 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li self.logger["vemlp_wandb"] = vemlp_wandb if "tensorboard" in default_backend: - self.logger["tensorboard"] = _TensorboardAdapter() + self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name) if "console" in default_backend: - from verl.utils.logger.aggregate_logger import LocalLogger + from verl.utils.logger import LocalLogger self.console_logger = LocalLogger(print_to_console=True) self.logger["console"] = self.console_logger @@ -174,7 +177,7 @@ def log(self, data, step): for k, v in data.items(): title, series = k.split("/", 1) - if isinstance(v, (int, float, np.floating, np.integer)): + if isinstance(v, int | float | np.floating | np.integer): logger.report_scalar( title=title, series=series, @@ -189,19 +192,22 @@ def log(self, data, step): iteration=step, ) else: - logger.warning(f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This invocation of ClearML logger\'s function is incorrect so this attribute was dropped. ') + logger.warning( + f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This ' + f"invocation of ClearML logger's function is incorrect so this attribute was dropped. " + ) def finish(self): self._task.mark_completed() class _TensorboardAdapter: - def __init__(self): + def __init__(self, project_name, experiment_name): import os from torch.utils.tensorboard import SummaryWriter - tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log") + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}") os.makedirs(tensorboard_dir, exist_ok=True) print(f"Saving tensorboard log to {tensorboard_dir}.") self.writer = SummaryWriter(tensorboard_dir) @@ -222,7 +228,7 @@ def log(self, data, step): mlflow.log_metrics(metrics=results, step=step) -def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]: +def _compute_mlflow_params_from_objects(params) -> dict[str, Any]: if params is None: return {} @@ -249,7 +255,7 @@ def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): return x -def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]: +def _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]: import pandas as pd ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] @@ -271,13 +277,27 @@ def log(self, loggers, samples, step): self.log_generations_to_clearml(samples, step) if "tensorboard" in loggers: self.log_generations_to_tensorboard(samples, step) - + + if "vemlp_wandb" in loggers: + self.log_generations_to_vemlp_wandb(samples, step) + + def log_generations_to_vemlp_wandb(self, samples, step): + from volcengine_ml_platform import wandb as vemlp_wandb + + self._log_generations_to_wandb(samples, step, vemlp_wandb) + def log_generations_to_wandb(self, samples, step): - """Log samples to wandb as a table""" import wandb + self._log_generations_to_wandb(samples, step, wandb) + + def _log_generations_to_wandb(self, samples, step, wandb): + """Log samples to wandb as a table""" + # Create column names for all samples - columns = ["step"] + sum([[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []) + columns = ["step"] + sum( + [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] + ) if not hasattr(self, "validation_table"): # Initialize the table on first call @@ -303,23 +323,16 @@ def log_generations_to_swanlab(self, samples, step): """Log samples to swanlab as text""" import swanlab - swanlab_text_list = [] - for i, sample in enumerate(samples): - row_text = f""" - input: {sample[0]} - - --- - - output: {sample[1]} - - --- - - score: {sample[2]} - """ - swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) + swanlab_table = swanlab.echarts.Table() + + # Create column names + headers = ["step", "input", "output", "score"] + + swanlab_row_list = [[step, *sample] for sample in samples] + swanlab_table.add(headers=headers, rows=swanlab_row_list) # Log to swanlab - swanlab.log({"val/generations": swanlab_text_list}, step=step) + swanlab.log({"val/generations": swanlab_table}, step=step) def log_generations_to_mlflow(self, samples, step): """Log validation generation to mlflow as artifacts""" @@ -370,36 +383,37 @@ def log_generations_to_clearml(self, samples, step): table_plot=pd.DataFrame.from_records(table), iteration=step, ) - + def log_generations_to_tensorboard(self, samples, step): """Log samples to tensorboard as text""" # Initialize tensorboard writer if not exists if not hasattr(self, "writer"): from torch.utils.tensorboard import SummaryWriter + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log") os.makedirs(tensorboard_dir, exist_ok=True) self.writer = SummaryWriter(log_dir=tensorboard_dir) - + # Format the samples data into readable text text_content = f"**Generation Results - Step {step}**\n\n" - + for i, sample in enumerate(samples): text_content += f"### Sample {i + 1}\n" - + # Assuming sample contains [input, output, score] if len(sample) >= 3: input_text, output_text, score = sample[0], sample[1], sample[2] - + text_content += f"**Input:** {input_text}\n\n" text_content += f"**Output:** {output_text}\n\n" text_content += f"**Score:** {score}\n\n" else: # Handle cases where sample format might be different text_content += f"**Data:** {sample}\n\n" - + text_content += "---\n\n" - + # Log to tensorboard as text - self.writer.add_text('val/generations', text_content, step) + self.writer.add_text("val/generations", text_content, step) # Flush to ensure data is written - self.writer.flush() \ No newline at end of file + self.writer.flush() diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py index 3670b1d20..1669f6f32 100644 --- a/verl/utils/ulysses.py +++ b/verl/utils/ulysses.py @@ -17,7 +17,7 @@ Inspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py """ -from typing import Any, Optional, Tuple +from typing import Any, Optional import torch import torch.distributed as dist @@ -179,7 +179,7 @@ def forward( return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op) @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0] return ( None, @@ -234,7 +234,13 @@ def backward(ctx: Any, grad_output: Tensor) -> Any: ) -def gather_outpus_and_unpad( +def gather_outpus_and_unpad(*args, **kwargs): + raise RuntimeError( + "please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad" + ) + + +def gather_outputs_and_unpad( x: Tensor, gather_dim: int, unpad_dim: int = None, @@ -268,12 +274,11 @@ def gather_outpus_and_unpad( x = _unpad_tensor(x, unpad_dim, padding_size) return x -def ulysses_pad( - input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 -): + +def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1): if position_ids_rmpad is not None: - assert position_ids_rmpad.size(0) == 1 - assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1) + assert position_ids_rmpad.size(-2) == 1 + assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1) if sp_size <= 1: return input_ids_rmpad, position_ids_rmpad, 0 _, total_seq_len = input_ids_rmpad.shape @@ -282,10 +287,15 @@ def ulysses_pad( input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0) if position_ids_rmpad is not None: pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0) + if position_ids_rmpad.dim() == 3: + pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1) position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1) return input_ids_rmpad, position_ids_rmpad, pad_size -def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1): + +def ulysses_pad_and_slice_inputs( + input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 +): """ Pad and slice input_ids to be divisible by sp_size Pad position_ids to be divisible by sp_size. @@ -304,9 +314,7 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmp torch.Tensor: padded and sliced position_ids int: pad size """ - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( - input_ids_rmpad, position_ids_rmpad, sp_size - ) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size) input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False) if position_ids_rmpad is not None: position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False) @@ -315,4 +323,6 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmp def validate_ulysses_config(num_heads, ulysses_sequence_size): if ulysses_sequence_size > 1: - assert num_heads % ulysses_sequence_size == 0, f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" + assert num_heads % ulysses_sequence_size == 0, ( + f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" + ) diff --git a/verl/utils/vllm_utils.py b/verl/utils/vllm_utils.py index c1b71452d..25ee6656d 100644 --- a/verl/utils/vllm_utils.py +++ b/verl/utils/vllm_utils.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering unsupported issues. + +from msgspec import field +from packaging import version as vs +from vllm.lora.models import LoRAModel +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + +from verl.third_party.vllm import get_version + +# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering +# unsupported issues. SUPPORTED_MOE_MODELS = [] try: from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM + SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM) SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM) except ImportError: @@ -24,39 +36,32 @@ try: from vllm.model_executor.models.mixtral import MixtralForCausalLM + SUPPORTED_MOE_MODELS.append(MixtralForCausalLM) except ImportError: pass try: from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM + SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM) except ImportError: pass try: from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM + SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM) except ImportError: pass try: from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration + SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration) except ImportError: pass -from typing import List - -from msgspec import field -from packaging import version as vs -from vllm.lora.models import LoRAModel -from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - -from verl.third_party.vllm import get_version - def patch_vllm_moe_model_weight_loader(model): # this is a work around to load the weight of vllm fused moe model @@ -100,8 +105,8 @@ def patch_vllm_moe_model_weight_loader(model): class TensorLoRARequest(LoRARequest): - peft_config:dict = field(default=None) - lora_tensors:dict = field(default=None) + peft_config: dict = field(default=None) + lora_tensors: dict = field(default=None) class VLLMHijack: @@ -113,18 +118,16 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: Reason: VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. - To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to load memory-based LoRA tensors. + To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to + load memory-based LoRA tensors. """ try: - supported_lora_modules = ( - self._adapter_manager.supported_lora_modules) - packed_modules_mapping = ( - self._adapter_manager.packed_modules_mapping) - expected_lora_modules: List[str] = [] + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_modules: list[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: - expected_lora_modules.extend( - packed_modules_mapping[module]) + expected_lora_modules.extend(packed_modules_mapping[module]) else: expected_lora_modules.append(module) @@ -132,6 +135,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: lora_tensors = None from vllm.lora.peft_helper import PEFTHelper + if isinstance(lora_request, TensorLoRARequest): peft_config = lora_request.peft_config lora_tensors = lora_request.lora_tensors @@ -139,8 +143,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: else: lora_path = get_adapter_absolute_path(lora_request.lora_path) - peft_helper = PEFTHelper.from_local_dir( - lora_path, self.max_position_embeddings) + peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. @@ -150,8 +153,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: # to ensure correct loading of lora weights. model = self._adapter_manager.model hf_to_vllm_mapper = None - if (hasattr(model, "hf_to_vllm_mapper") - and model.hf_to_vllm_mapper is not None): + if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None: hf_to_vllm_mapper = model.hf_to_vllm_mapper if isinstance(lora_request, TensorLoRARequest): @@ -165,7 +167,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper + weights_mapper=hf_to_vllm_mapper, ) else: lora = self._lora_model_cls.from_local_checkpoint( @@ -175,18 +177,19 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + - self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper) + weights_mapper=hf_to_vllm_mapper, + ) except Exception as e: raise e if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}.") + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}." + ) return lora def do_hijack(target_cls, target_method_name, hooking_method): @@ -195,6 +198,6 @@ def do_hijack(target_cls, target_method_name, hooking_method): do_hijack(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter) -def is_version_ge(pkg:str='vllm', minver:str="0.7.3"): - """ check if the package version is greater than or equal to the minimum version """ - return vs.parse(get_version(pkg)) >= vs.parse(minver) \ No newline at end of file +def is_version_ge(pkg: str = "vllm", minver: str = "0.7.3"): + """check if the package version is greater than or equal to the minimum version""" + return vs.parse(get_version(pkg)) >= vs.parse(minver) diff --git a/verl/version/version b/verl/version/version index da18d2f65..04c4d903d 100644 --- a/verl/version/version +++ b/verl/version/version @@ -1 +1 @@ -0.3.1.dev +0.4.1.dev diff --git a/verl/workers/actor/base.py b/verl/workers/actor/base.py index 430c21858..2d1ba290d 100644 --- a/verl/workers/actor/base.py +++ b/verl/workers/actor/base.py @@ -16,7 +16,6 @@ """ from abc import ABC, abstractmethod -from typing import Dict import torch @@ -52,7 +51,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: pass @abstractmethod - def update_policy(self, data: DataProto) -> Dict: + def update_policy(self, data: DataProto) -> dict: """Update the policy with an iterator of DataProto Args: diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index cd10da986..d5cea3620 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -17,10 +17,8 @@ Single Process Actor """ -import itertools import logging import os -from typing import Tuple import torch from torch import nn @@ -28,14 +26,14 @@ import verl.utils.torch_functional as verl_F from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty -from verl.utils.debug import GPUMemoryLogger -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available +from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.utils.device import get_device_name, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.utils.torch_functional import logprobs_from_logits -from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor if is_cuda_available: @@ -67,14 +65,21 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + if self.config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + self.compute_entropy_from_logits = ( - torch.compile(verl_F.entropy_from_logits, dynamic=True) + torch.compile(entropy_from_logits, dynamic=True) if self.config.get("use_torch_compile", True) # use torch compile by default - else verl_F.entropy_from_logits + else entropy_from_logits ) self.device_name = get_device_name() - def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_micro_batch( + self, micro_batch, temperature, calculate_entropy=False + ) -> tuple[torch.Tensor, torch.Tensor]: """ Returns: entropy: # (bs, response_len) @@ -82,9 +87,15 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False """ response_length = micro_batch["responses"].size(-1) multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch: - for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) + if "multi_modal_inputs" in micro_batch.keys(): + if "image_bound" in micro_batch["multi_modal_inputs"][0]: # minicpm-o logic + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]] + else: + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + ) with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] @@ -96,21 +107,36 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary if position_ids.dim() == 3: - position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) # for compute the log_prob input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.use_ulysses_sp: - is_vlm_model = "multi_modal_inputs" in micro_batch + is_vlm_model = "multi_modal_inputs" in micro_batch.keys() if is_vlm_model: # vlm model's inputs will be sliced after embedding input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( @@ -136,6 +162,7 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False extra_args = {} if self.use_fused_kernels: extra_args["temperature"] = temperature + extra_args["return_dict"] = True output = self.actor_module( input_ids=input_ids_rmpad, @@ -166,19 +193,24 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False # compute entropy if calculate_entropy: - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) # gather log_prob if sp > 1 if self.use_ulysses_sp: # gather and unpad for the ulysses sp - log_probs = gather_outpus_and_unpad( + log_probs = gather_outputs_and_unpad( log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size, ) if calculate_entropy: - entropy_rmpad = gather_outpus_and_unpad( + entropy_rmpad = gather_outputs_and_unpad( entropy_rmpad, gather_dim=0, unpad_dim=0, @@ -208,6 +240,8 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False extra_args = {} if self.use_fused_kernels: extra_args["temperature"] = temperature + extra_args["return_dict"] = True + output = self.actor_module( input_ids=input_ids, attention_mask=attention_mask, @@ -228,7 +262,10 @@ def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) log_probs = logprobs_from_logits(logits, micro_batch["responses"]) if calculate_entropy: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) return entropy, log_probs @@ -275,29 +312,26 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te micro_batch_size = data.meta_info["micro_batch_size"] temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batches = data.split(micro_batch_size) log_probs_lst = [] entropy_lst = [] for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): - entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) + entropy, log_probs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) log_probs_lst.append(log_probs) if calculate_entropy: entropy_lst.append(entropy) @@ -306,11 +340,11 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te entropys = None if calculate_entropy: entropys = torch.concat(entropy_lst, dim=0) + if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - log_probs = log_probs[revert_indices] + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) return log_probs, entropys @@ -320,64 +354,56 @@ def update_policy(self, data: DataProto): self.actor_module.train() temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error - multi_turn = data.meta_info.get("multi_turn", False) - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] - if multi_turn: - select_keys.append("loss_mask") + select_keys = [ + "responses", + "response_mask", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] if self.config.use_kl_loss: select_keys.append("ref_log_prob") - batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) + mini_batches = data.split(self.config.ppo_mini_batch_size) metrics = {} - for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - # split batch into micro_batches + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) self.actor_optimizer.zero_grad() - for data in micro_batches: - # Support all hardwares - if isinstance(data, DataProto): - data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} - else: - data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] - if multi_turn: - response_mask = data["loss_mask"][:, -response_length:] - else: - response_mask = attention_mask[:, -response_length:] - - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] + for micro_batch in micro_batches: + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + response_mask = model_inputs["response_mask"] + old_log_prob = model_inputs["old_log_probs"] + advantages = model_inputs["advantages"] clip_ratio = self.config.clip_ratio - clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + clip_ratio_low = ( + self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + ) + clip_ratio_high = ( + self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + ) clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode @@ -386,20 +412,36 @@ def update_policy(self, data: DataProto): calculate_entropy = False if entropy_coeff != 0: calculate_entropy = True - entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy) - - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, + entropy, log_prob = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy ) + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + + if self.config.policy_loss.loss_mode == "vanilla": + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + ) + + else: + policy_loss_fn = get_policy_loss_fn(loss_mode) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + ) + if entropy_coeff != 0: entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) @@ -409,32 +451,36 @@ def update_policy(self, data: DataProto): policy_loss = pg_loss if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] + ref_log_prob = model_inputs["ref_log_prob"] # compute kl loss - kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics["actor/kl_loss"] = kl_loss.detach().item() - metrics["actor/kl_coef"] = self.config.kl_loss_coef + micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() + micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) else: loss = policy_loss / self.gradient_accumulation loss.backward() - data = { - "actor/pg_loss": pg_loss.detach().item(), - "actor/pg_clipfrac": pg_clipfrac.detach().item(), - "actor/ppo_kl": ppo_kl.detach().item(), - "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), - } - append_to_dict(metrics, data) + micro_batch_metrics.update( + { + "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + ) + append_to_dict(metrics, micro_batch_metrics) grad_norm = self._optimizer_step() - data = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() return metrics diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index f0338006b..ce52956d0 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -19,12 +19,11 @@ Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer """ -import copy import itertools import logging import os from functools import partial -from typing import Dict, Iterable +from typing import Iterable import torch import torch.distributed @@ -38,12 +37,13 @@ from torch import nn from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty -from verl.utils.debug import GPUMemoryLogger -from verl.utils.debug.profile import Profiler +from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits from verl.utils.megatron_utils import get_model_config +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.profiler.profile import Profiler from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor @@ -85,18 +85,22 @@ def __init__( ``model_config.hidden_size`` hf_config (PretrainedConfig): huggingface config tf_config (TransformerConfig): mcore transformer config - actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this pp stage. - each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for more details. + actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this + pp stage. + each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for + more details. The actor module has some constraints to follow in order to use the updating logics implemented here - 1. It must implement unpad_input before any computation and pad_input after all the computation. Remove padding is an + 1. It must implement unpad_input before any computation and pad_input after all the computation. + Remove padding is an optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py). 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size], where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size of the hidden state is [total_nnz // tp, 1, hidden_size]. - actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. It implements + actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. + It implements zero1 optimizer that shards the optimizer state across dp ranks. >>> from megatron.training import get_model @@ -119,6 +123,13 @@ def __init__( self.actor_module = actor_module self.actor_optimizer: DistributedOptimizer = actor_optimizer self.prof = Profiler(self.config.profile) + self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if self.use_fused_kernels: + from verl.models.mcore.model_forward_fused import patch_fused_forward + + for model in self.actor_module: + patch_fused_forward(model) + self.optimizer_step_args = OmegaConf.create( { "skip_grad": None, @@ -166,7 +177,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te Returns: DataProto: torch.Tensor: the log_prob tensor """ - data.to(torch.cuda.current_device()) + data.to(get_device_id()) data.batch = data.batch.contiguous() use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) micro_batch_size = data.meta_info.get("micro_batch_size", None) @@ -183,7 +194,8 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): return {"log_probs": log_probs} # We make recompute_old_log_prob by default here. - # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by user outside + # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be + # handled by user outside recompute_old_log_prob = self.config.get("recompute_old_log_prob", True) entropys = torch.Tensor() @@ -195,7 +207,15 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): response = batch["responses"] response_length = response.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn, calculate_entropy=calculate_entropy, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len) + output = self.forward_backward_batch( + data, + forward_only=True, + post_process_fn=compute_logprobs_fn, + calculate_entropy=calculate_entropy, + use_dynamic_bsz=use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + ) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank if calculate_entropy: @@ -210,7 +230,9 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] else: - log_probs = torch.empty(size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device) + log_probs = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) # broadcast across pp ranks torch.distributed.broadcast( @@ -231,7 +253,9 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) entropys = entropys[revert_indices] else: - entropys = torch.empty(size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device) + entropys = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) # broadcast across pp ranks torch.distributed.broadcast( tensor=entropys, @@ -241,7 +265,7 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): ) # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() return log_probs, entropys @@ -251,26 +275,42 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: Args: data (DataProto): a DataProto containing keys - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where ``sequence_length = prompt_length + response_length`` + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where + ``sequence_length = prompt_length + response_length`` ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64 ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64 - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that responses = input_ids[:, -response_length:] + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that + responses = input_ids[:, -response_length:] - ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability of responses. + ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability + of responses. - ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of responses. + ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of + responses. See PPO paper for details. https://arxiv.org/abs/1707.06347 Returns: """ - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "response_mask", + "position_ids", + "old_log_probs", + "advantages", + ] if self.config.use_kl_loss: select_keys.append("ref_log_prob") - data = data.select(batch_keys=select_keys) + self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + data = data.select(select_keys, ["multi_modal_inputs"]) + else: + data = data.select(batch_keys=select_keys) return data.make_iterator( mini_batch_size=self.config.ppo_mini_batch_size, epochs=self.config.ppo_epochs, @@ -278,7 +318,17 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: dataloader_kwargs={"shuffle": self.config.shuffle}, ) - def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None, calculate_entropy=False, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None, mini_batch_size=None): + def forward_backward_batch( + self, + data: DataProto, + forward_only=False, + post_process_fn=None, + calculate_entropy=False, + use_dynamic_bsz=False, + micro_batch_size=None, + max_token_len=None, + mini_batch_size=None, + ): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -287,9 +337,24 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce # broadcast from last pp rank to all other pp ranks # TODO: actually, we just need to control the sampling order. mini_batch = data - broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + mini_batch.batch, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) # split into micro-batches mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( + list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) + ).to(torch.int64) + + if mini_batch.batch["position_ids"].dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][ + :, 0 + ] # mcore patch recompute qwen2vl's pos ids during forward indices = None if use_dynamic_bsz: @@ -297,13 +362,22 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, + num_batches_divided_by=microbatch_group_size_per_vp_stage, + max_token_len=max_token_len, + ) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " + f"{microbatch_group_size_per_vp_stage} for megatron backend" + ) else: micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) total_seqlen = max_token_len else: - assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + assert micro_batch_size is not None, ( + "micro_batch_size is needed to be passed in when not using dynamic batch size" + ) micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -329,33 +403,61 @@ def loss_func(output, data, meta_info): responses = data["responses"] response_length = responses.size(1) - attention_mask = data["attention_mask"] - response_mask = attention_mask[:, -response_length:] + response_mask = data["response_mask"].to(bool) loss_agg_mode = self.config.loss_agg_mode # compute policy loss log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous() ret_entropy = None + stats = {} if not forward_only: old_log_prob = data["old_log_probs"] advantages = data["advantages"] - clip_ratio = meta_info["clip_ratio"] + clip_ratio = self.config.clip_ratio clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - clip_ratio_c = meta_info["clip_ratio_c"] - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, + + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + + if self.config.policy_loss.loss_mode == "vanilla": + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + ) + + else: + policy_loss_fn = get_policy_loss_fn(loss_mode) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + ) + + stats.update( + { + "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } ) policy_loss = pg_loss + if calculate_entropy: entropy = output["entropy"][:, -response_length - 1 : -1].contiguous() if not forward_only: @@ -365,7 +467,6 @@ def loss_func(output, data, meta_info): else: ret_entropy = entropy - stats = {} if forward_only: policy_loss = torch.tensor(1.0, device=device) else: @@ -380,14 +481,6 @@ def loss_func(output, data, meta_info): metrics["actor/kl_coef"] = self.config.kl_loss_coef # return loss and stats - stats.update( - { - "actor/pg_loss": pg_loss.detach().item(), - "actor/pg_clipfrac": pg_clipfrac.detach().item(), - "actor/ppo_kl": ppo_kl.detach().item(), - "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), - } - ) append_to_dict(metrics, stats) return policy_loss, [metrics, ret_entropy] @@ -395,39 +488,66 @@ def loss_func(output, data, meta_info): def forward_step(batch_iter, model): batch = next(batch_iter) input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] + attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + for key in batch["multi_modal_inputs"][0].keys(): + idxs = batch["multi_modal_inputs_idx"] + mmi = batch["multi_modal_inputs"] + multi_modal_inputs[key] = torch.cat( + [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0 + ) responses = batch["responses"] response_length = responses.size(1) - label = copy.deepcopy(position_ids) + label = position_ids.clone() label[:, -response_length - 1 : -1] = responses - label_mask = copy.deepcopy(attention_mask) + label_mask = attention_mask.clone() label_mask[:, : -response_length - 1] = False label_mask[:, -1] = False - def logits_processor(logits, label, label_mask): - assert logits.shape[:2] == label.shape[:2] - assert label.shape == label_mask.shape - - ret = {} - - if calculate_entropy: - entropy = vocab_parallel_entropy(logits) - ret["entropy"] = entropy - - log_probs = vocab_parallel_log_probs_from_logits(logits, label) - log_probs = log_probs.masked_fill(~label_mask, 0.0) - ret["log_probs"] = log_probs - return ret - - logits_processor_args = {"label": label, "label_mask": label_mask} - - from verl.models.mcore import get_mcore_forward_fn - - forward_fn = get_mcore_forward_fn(self.hf_config) + from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn + + if self.use_fused_kernels: + forward_fn = get_mcore_forward_fused_fn(self.hf_config) + # return dict of [logits, entropy] + output = forward_fn( + model, + input_ids, + position_ids, + attention_mask, + sequence_parallel=self.tf_config.sequence_parallel, + multi_modal_inputs=multi_modal_inputs, + labels=label, + labels_mask=label_mask, + ) + else: + forward_fn = get_mcore_forward_fn(self.hf_config) - output = forward_fn(model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel, logits_processor=logits_processor, logits_processor_args=logits_processor_args) + def logits_processor(logits, label, label_mask): + assert logits.shape[:2] == label.shape[:2] + assert label.shape == label_mask.shape + ret = {} + if calculate_entropy: + entropy = vocab_parallel_entropy(logits) + ret["entropy"] = entropy + log_probs = vocab_parallel_log_probs_from_logits(logits, label) + log_probs = log_probs.masked_fill(~label_mask, 0.0) + ret["log_probs"] = log_probs + return ret + + logits_processor_args = {"label": label, "label_mask": label_mask} + output = forward_fn( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel=self.tf_config.sequence_parallel, + multi_modal_inputs=multi_modal_inputs, + logits_processor=logits_processor, + logits_processor_args=logits_processor_args, + ) if forward_only: meta_info = None @@ -466,13 +586,19 @@ def logits_processor(logits, label, label_mask): forward_only=forward_only, ) # loss_reduces contains the stats returned from loss_func + + if self.has_multi_modal_inputs: + data.batch.pop("multi_modal_inputs") + data.batch.pop("multi_modal_inputs_idx") + data.non_tensor_batch.pop("multi_modal_inputs") + losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: losses_reduced["indices"] = indices return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) - def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: + def update_policy(self, dataloader: Iterable[DataProto]) -> dict: """Update the policy with an iterator of DataProto Args: @@ -487,7 +613,7 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: metrics = {} self.prof.start() for data in dataloader: - data.to(torch.cuda.current_device()) + data.to(get_device_id()) self.actor_optimizer.zero_grad() # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm for chunk in self.actor_module: @@ -502,15 +628,21 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: max_token_len = None if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size - metric_micro_batch = self.forward_backward_batch(data, calculate_entropy=calculate_entropy, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size) + metric_micro_batch = self.forward_backward_batch( + data, + calculate_entropy=calculate_entropy, + use_dynamic_bsz=self.config.use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + mini_batch_size=self.config.ppo_mini_batch_size, + ) metric_micro_batch = metric_micro_batch["output"] for metric in metric_micro_batch: # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() - learning_rate = self.actor_optimizer.param_groups[-1]["lr"] - data = {"actor/grad_norm": grad_norm, "actor/lr": learning_rate} + data = {"actor/grad_norm": grad_norm} append_to_dict(metrics, data) if update_successful: @@ -522,5 +654,5 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: # add empty cache after each compute self.prof.stop_and_save() self.prof.stop_trace() - torch.cuda.empty_cache() + get_torch_device().empty_cache() return metrics diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index efe5dd3ef..4d7c87ef7 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -15,37 +15,29 @@ Implement a multiprocess PPOCritic """ -import itertools import logging import os import torch import torch.distributed -from flash_attn.bert_padding import (index_first_axis, pad_input, rearrange, - unpad_input) from torch import nn, optim from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto from verl.trainer.ppo import core_algos -from verl.utils.debug import GPUMemoryLogger -from verl.utils.device import (get_device_name, get_torch_device, - is_cuda_available, is_npu_available) +from verl.utils.device import get_device_name, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import (get_reverse_idx, - rearrange_micro_batches) +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.utils.torch_functional import masked_mean -from verl.utils.ulysses import (gather_outpus_and_unpad, - ulysses_pad_and_slice_inputs) +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic if is_cuda_available: - from flash_attn.bert_padding import (index_first_axis, pad_input, - rearrange, unpad_input) + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input elif is_npu_available: - from transformers.integrations.npu_flash_attention import ( - index_first_axis, pad_input, rearrange, unpad_input) + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -65,9 +57,11 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt def _forward_micro_batch(self, micro_batch): response_length = micro_batch["responses"].size(-1) multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch: + if "multi_modal_inputs" in micro_batch.keys(): for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + ) with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] @@ -78,18 +72,28 @@ def _forward_micro_batch(self, micro_batch): position_ids = position_ids.transpose(0, 1) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary if position_ids.dim() == 3: - position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.critic_module( @@ -99,12 +103,19 @@ def _forward_micro_batch(self, micro_batch): **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating - values_rmpad = output.logits - values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values_rmpad = output[2].squeeze(0).unsqueeze(-1) + else: + values_rmpad = output.logits + values_rmpad = values_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - values_rmpad = gather_outpus_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) + values_rmpad = gather_outputs_and_unpad( + values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) @@ -117,7 +128,11 @@ def _forward_micro_batch(self, micro_batch): **multi_modal_inputs, use_cache=False, ) # prevent model thinks we are generating - values = output.logits + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values = output[2] + else: + values = output.logits values = values[:, -response_length - 1 : -1].squeeze(-1) return values @@ -143,43 +158,32 @@ def _optimizer_step(self): def compute_values(self, data: DataProto) -> torch.Tensor: self.critic_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "response_mask", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batches = data.split(micro_batch_size) values_lst = [] for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} - + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): - values = self._forward_micro_batch(micro_batch) + values = self._forward_micro_batch(model_inputs) values_lst.append(values) values = torch.concat(values_lst, dim=0) if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - values = values[revert_indices] - - responses = data.batch["responses"] - attention_mask = data.batch["attention_mask"] - response_length = responses.size(1) - response_mask = attention_mask[:, -response_length:] - values = values * response_mask # Only action tokens have values + values = restore_dynamic_batch(values, batch_idx_list) + + response_mask = data.batch["response_mask"] + values = values * response_mask # Only action tokens have values return values @GPUMemoryLogger(role="dp critic", logger=logger) @@ -188,53 +192,37 @@ def update_critic(self, data: DataProto): self.critic_module.train() metrics = {} - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] - batch = data.select(batch_keys=select_keys).batch + select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"] has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) - - for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + mini_batches = data.split(self.config.ppo_mini_batch_size) + + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu self.critic_optimizer.zero_grad() - for data in micro_batches: - # Support all devices - if isinstance(data, DataProto): - data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} - else: - data = data.to(get_torch_device().current_device()) # critic device is cpu when using offload - responses = data["responses"] - attention_mask = data["attention_mask"] - values = data["values"] - returns = data["returns"] - response_length = responses.size(1) - - response_mask = attention_mask[:, -response_length:] - - vpreds = self._forward_micro_batch(data) - - # assert not torch.any(torch.isnan(vpreds)).item() + for micro_batch in micro_batches: + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + response_mask = model_inputs["response_mask"] + values = model_inputs["values"] + returns = model_inputs["returns"] + vpreds = self._forward_micro_batch(model_inputs) vf_loss, vf_clipfrac = core_algos.compute_value_loss( vpreds=vpreds, values=values, @@ -245,22 +233,24 @@ def update_critic(self, data: DataProto): ) if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) else: loss = vf_loss / self.gradient_accumulation loss.backward() - data = { - "critic/vf_loss": vf_loss.detach().item(), - "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), - } + micro_batch_metrics.update( + { + "critic/vf_loss": vf_loss.detach().item(), + "critic/vf_clipfrac": vf_clipfrac.detach().item(), + "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), + } + ) - append_to_dict(metrics, data) + append_to_dict(metrics, micro_batch_metrics) grad_norm = self._optimizer_step() - data = {"critic/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) self.critic_optimizer.zero_grad() return metrics diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 7e63f6108..1d44a8876 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -31,11 +31,11 @@ from verl import DataProto from verl.trainer.ppo import core_algos -from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import (get_reverse_idx, - rearrange_micro_batches) +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean from verl.workers.critic import BasePPOCritic @@ -91,7 +91,7 @@ def _validate_config(self, config) -> None: @GPUMemoryLogger("megatron critic", logger=logger) def compute_values(self, data: DataProto) -> DataProto: - data.to(torch.cuda.current_device()) + data.to(get_device_id()) responses = data.batch["responses"] attention_mask = data.batch["attention_mask"] use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) @@ -103,7 +103,14 @@ def compute_values(self, data: DataProto) -> DataProto: max_token_len = max_token_len * self.config.megatron.context_parallel_size response_length = responses.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data=data, forward_only=True, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=None) + output = self.forward_backward_batch( + data=data, + forward_only=True, + use_dynamic_bsz=use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + mini_batch_size=None, + ) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank values = [o["vpreds"] for o in output["output"]] # (bs, seq_size, vocal_size) @@ -118,9 +125,11 @@ def compute_values(self, data: DataProto) -> DataProto: values = torch.empty_like(attention_mask, dtype=torch.float32) # each tp ranks should contain the same value - values = values[:, -response_length - 1 : -1] # Values are predicted at the ends of prefixes, e.g., the last prompt token + values = values[ + :, -response_length - 1 : -1 + ] # Values are predicted at the ends of prefixes, e.g., the last prompt token response_mask = attention_mask[:, -response_length:] - values = values * response_mask # Only action tokens have values + values = values * response_mask # Only action tokens have values values = values.contiguous() # sync among pp ranks @@ -131,7 +140,7 @@ def compute_values(self, data: DataProto) -> DataProto: ) # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() return values @@ -145,12 +154,24 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: dataloader_kwargs={"shuffle": self.config.shuffle}, ) - def forward_backward_batch(self, data: DataProto, forward_only=False, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None, mini_batch_size=None): + def forward_backward_batch( + self, + data: DataProto, + forward_only=False, + use_dynamic_bsz=False, + micro_batch_size=None, + max_token_len=None, + mini_batch_size=None, + ): # broadcast from last pp rank to all other pp ranks mini_batch = data - mini_batch.to(torch.cuda.current_device()) + mini_batch.to(get_device_id()) mini_batch.batch = mini_batch.batch.contiguous() - broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + mini_batch.batch, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) # split into micro-batches mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) @@ -160,13 +181,22 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, use_dynami vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, + num_batches_divided_by=microbatch_group_size_per_vp_stage, + max_token_len=max_token_len, + ) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " + f"{microbatch_group_size_per_vp_stage} for megatron backend" + ) else: micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) total_seqlen = max_token_len else: - assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + assert micro_batch_size is not None, ( + "micro_batch_size is needed to be passed in when not using dynamic batch size" + ) micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -276,7 +306,14 @@ def update_critic(self, dataloader: Iterable[DataProto]): max_token_len = None if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size - metric_micro_batch = self.forward_backward_batch(data, forward_only=False, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size) + metric_micro_batch = self.forward_backward_batch( + data, + forward_only=False, + use_dynamic_bsz=self.config.use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + mini_batch_size=self.config.ppo_mini_batch_size, + ) metric_micro_batch = metric_micro_batch["output"] update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step() learning_rate = self.critic_optimizer.param_groups[-1]["lr"] @@ -293,5 +330,5 @@ def update_critic(self, dataloader: Iterable[DataProto]): append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics. # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() return metrics diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 1b87d2f5d..30e117000 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -20,14 +20,14 @@ import os import warnings from dataclasses import asdict -from typing import Union +from typing import Any import psutil import torch import torch.distributed import torch.distributed as dist from codetiming import Timer -from omegaconf import DictConfig, open_dict +from omegaconf import DictConfig, OmegaConf, open_dict from peft import LoraConfig, TaskType, get_peft_model from safetensors.torch import save_file from torch.distributed.device_mesh import init_device_mesh @@ -41,8 +41,15 @@ from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, + is_cuda_available, + is_npu_available, +) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( @@ -62,6 +69,8 @@ ) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask +from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer +from verl.utils.profiler.performance import reduce_timing from verl.utils.py_functional import convert_to_regular_types from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -75,7 +84,9 @@ def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: - device_mesh = init_device_mesh(device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]) + device_mesh = init_device_mesh( + device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) return device_mesh @@ -91,21 +102,28 @@ def get_sharding_strategy(device_mesh): return sharding_strategy -class ActorRolloutRefWorker(Worker): +class ActorRolloutRefWorker(Worker, DistProfilerExtension): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy or a hybrid engine based on the config.rollout """ - def __init__(self, config: DictConfig, role: str): - super().__init__() + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + self.config = config + self.profile_option = kwargs.get("profile_option", None) import torch.distributed if not torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) - torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl", rank=rank, world_size=world_size) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + rank=rank, + world_size=world_size, + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -117,7 +135,9 @@ def __init__(self, config: DictConfig, role: str): self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self._lora_rank = self.config.model.get("lora_rank", 0) @@ -130,6 +150,17 @@ def __init__(self, config: DictConfig, role: str): self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + profiler_config = omega_conf_to_dataclass(config.get("profiler")) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, option=self.profile_option) + ) + self._is_offload_param = False self._is_offload_optimizer = False if self._is_actor: @@ -143,19 +174,32 @@ def __init__(self, config: DictConfig, role: str): if self._is_actor: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size - assert self.config.actor.ppo_mini_batch_size > 0, f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" + assert self.config.actor.ppo_mini_batch_size > 0, ( + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " + f"normalization" + ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: - self.config.actor.ppo_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.actor.ppo_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size if self.config.actor.ppo_micro_batch_size_per_gpu is not None: - assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" - assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) # normalize rollout config if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: - self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.rollout.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: @@ -193,6 +237,12 @@ def _build_model_optimizer( self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + torch_dtype = fsdp_config.get("model_dtype", None) if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 @@ -200,7 +250,9 @@ def _build_model_optimizer( torch_dtype = PrecisionType.to_dtype(torch_dtype) # override model kwargs - actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2") + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + ) # patch for kimi-vl if getattr(actor_model_config, "model_type", None) == "kimi_vl": @@ -219,7 +271,9 @@ def _build_model_optimizer( print(f"Model config after override: {actor_model_config}") # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang - init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -241,11 +295,17 @@ def _build_model_optimizer( _apply_liger_kernel_to_instance(model=actor_module) + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + apply_monkey_patch( model=actor_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, ) # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 @@ -257,7 +317,14 @@ def _build_model_optimizer( print("Applying LoRA to actor module") actor_module.enable_input_require_grads() # Convert config to regular Python types before creating PEFT model - lora_config = {"task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, "target_modules": convert_to_regular_types(self.config.model.target_modules), "bias": "none"} + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), + "bias": "none", + } actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) torch.distributed.barrier() @@ -279,7 +346,11 @@ def _build_model_optimizer( mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get("wrap_policy", None), is_lora=self.config.model.get("lora_rank", 0) > 0) + auto_wrap_policy = get_fsdp_wrap_policy( + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) if self._is_rollout and self.config.rollout.name == "hf": # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma @@ -301,18 +372,20 @@ def _build_model_optimizer( actor_module, cpu_offload=cpu_offload, param_init_fn=init_fn, - use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=get_torch_device().current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, - forward_prefetch=False, + use_orig_params=self.config.actor.fsdp_config.get("use_orig_params", False), + forward_prefetch=self.config.actor.fsdp_config.get("forward_prefetch", False), ) elif fsdp_strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True) + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) if role == "actor" and fsdp_config.offload_policy: cpu_offload = CPUOffloadPolicy(pin_memory=True) self._is_offload_param = False @@ -362,9 +435,17 @@ def _build_model_optimizer( print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") if warmup_style == "constant": - actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps) + actor_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps + ) elif warmup_style == "cosine": - actor_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles) + actor_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=actor_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) else: raise NotImplementedError(f"Warmup style {warmup_style} is not supported") @@ -381,8 +462,12 @@ def _build_rollout(self, trust_remote_code=False): # TODO(sgm): support FSDP hybrid shard for larger model infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + ) rollout_name = self.config.rollout.name if rollout_name == "hf": from verl.workers.rollout import HFRollout @@ -393,22 +478,29 @@ def _build_rollout(self, trust_remote_code=False): # TODO: a sharding manager that do nothing? elif rollout_name == "vllm": - from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.rollout.vllm_rollout import vLLMRollout from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) - lora_kwargs = {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} if self._is_lora else {} + lora_kwargs = ( + {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} + if self._is_lora + else {} + ) # lora_kwargs = {} - if vllm_mode == "customized": - rollout = vLLMRollout(actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, **lora_kwargs) - elif vllm_mode == "spmd": - from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout + from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout - vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout - rollout = vllm_rollout_cls(model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, device_mesh=rollout_device_mesh, trust_remote_code=trust_remote_code, **lora_kwargs) - else: - raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + **lora_kwargs, + ) log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) full_params = torch.distributed.get_world_size() == 1 @@ -416,6 +508,7 @@ def _build_rollout(self, trust_remote_code=False): module=self.actor_module_fsdp, inference_engine=rollout.inference_engine, model_config=self.actor_model_config, + rollout_config=self.config.rollout, full_params=full_params, device_mesh=rollout_device_mesh, offload_param=self._is_offload_param, @@ -424,13 +517,7 @@ def _build_rollout(self, trust_remote_code=False): ) log_gpu_memory_usage("After building sharding manager", logger=logger) - elif rollout_name in ["sglang", "sglang_async"]: - if rollout_name == "sglang_async": - warnings.warn( - "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.", - DeprecationWarning, - stacklevel=2, - ) + elif rollout_name == "sglang": from verl.workers.rollout.sglang_rollout import SGLangRollout # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to @@ -447,7 +534,7 @@ def _build_rollout(self, trust_remote_code=False): rollout = SGLangRollout( actor_module=local_path, config=self.config.rollout, - tokenizer=self.tokenizer, + processing_class=self.processor if self.processor is not None else self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, ) @@ -459,9 +546,11 @@ def _build_rollout(self, trust_remote_code=False): module=self.actor_module_fsdp, inference_engine=rollout._engine, model_config=self.actor_model_config, + rollout_config=self.config.rollout, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, offload_param=self._is_offload_param, + multi_stage_wake_up=self.config.rollout.multi_stage_wake_up, ) log_gpu_memory_usage("After building sharding manager", logger=logger) @@ -477,8 +566,6 @@ def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) - from omegaconf import OmegaConf - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) use_remove_padding = self.config.model.get("use_remove_padding", False) @@ -525,16 +612,20 @@ def init_model(self): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - # load from checkpoint + if self._is_actor: OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding self.config.actor.use_fused_kernels = use_fused_kernels - self.actor = DataParallelPPOActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer) + self.actor = DataParallelPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + self.rollout, self.rollout_sharding_manager = self._build_rollout( + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) if self._is_ref: local_path = copy_to_local(self.config.model.path, use_shm=use_shm) @@ -562,19 +653,33 @@ def init_model(self): optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.actor.checkpoint.contents, + checkpoint_config=self.config.actor.checkpoint, + ) + + if not self._is_actor and self._is_rollout: + # If ActorRolloutRefWorker is initialized as a standalone rollout, + # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. + + checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=None, + lr_scheduler=None, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=checkpoint_contents, ) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) @@ -584,7 +689,9 @@ def update_actor(self, data: DataProto): delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + metrics["perf/mfu/actor"] = ( + estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + ) metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) @@ -609,27 +716,39 @@ def update_actor(self, data: DataProto): return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="red", role="rollout_generate") def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(get_torch_device().current_device()) + prompts = prompts.to(get_device_id()) assert self._is_rollout meta_info = { - "eos_token_id": self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id, + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, } prompts.meta_info.update(meta_info) + timing_generate = {} with self.rollout_sharding_manager: log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) log_gpu_memory_usage("After rollout generation", logger=logger) output = self.rollout_sharding_manager.postprocess_data(output) + timing_generate.update(self.rollout_sharding_manager.timing) + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate = reduce_timing(timing_generate) + output.meta_info["timing"] = timing_generate output = output.to("cpu") # clear kv cache @@ -637,6 +756,7 @@ def generate_sequences(self, prompts: DataProto): return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: DataProto): # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation @@ -649,7 +769,7 @@ def compute_log_prob(self, data: DataProto): is_lora = data.meta_info.pop("is_lora", False) adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -680,6 +800,7 @@ def compute_log_prob(self, data: DataProto): return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora applied is the ref @@ -692,7 +813,7 @@ def compute_ref_log_prob(self, data: DataProto): # else: # otherwise, the class have a standalone ref model # Support all hardwares - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -716,13 +837,17 @@ def compute_ref_log_prob(self, data: DataProto): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + from verl.utils.logger import log_with_rank + # only support save and load ckpt for actor assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) - self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) dist.barrier() if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): @@ -737,29 +862,41 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to peft_config["target_modules"] = list(peft_config["target_modules"]) try: if fsdp_version(self.actor_module_fsdp) > 0: - self.actor_module_fsdp = self.actor_module_fsdp.cuda() + self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) lora_params = layered_summon_lora_params(self.actor_module_fsdp) if dist.get_rank() == 0: save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: json.dump(peft_config, f, ensure_ascii=False, indent=4) except Exception as e: - if dist.get_rank() == 0: - print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})") + log_with_rank( + f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True + ) dist.barrier() - if dist.get_rank() == 0: - print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}") + log_with_rank( + f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", + rank=dist.get_rank(), + logger=logger, + log_only_rank_0=True, + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): + assert self._is_actor or (not self._is_actor and self._is_rollout), ( + f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " + f"{self._is_actor} and {self._is_rollout}" + ) + if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) - self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load) + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -767,14 +904,29 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False if self._is_offload_optimizer: offload_fsdp_optimizer(self.actor_optimizer) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) -class CriticWorker(Worker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + +class CriticWorker(Worker, DistProfilerExtension): def __init__(self, config): - super().__init__() + Worker.__init__(self) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + ) import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") + torch.distributed.init_process_group( + backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) + ) self.config = config # build device mesh for Ulysses Sequence Parallel @@ -788,7 +940,9 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -800,14 +954,24 @@ def __init__(self, config): self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size if self.config.ppo_micro_batch_size is not None: - self.config.ppo_micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - self.config.forward_micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + self.config.ppo_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.forward_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size if self.config.ppo_micro_batch_size_per_gpu is not None: - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" - assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) self._is_lora = self.config.model.get("lora_rank", 0) > 0 def _build_critic_model_optimizer(self, config): @@ -815,7 +979,7 @@ def _build_critic_model_optimizer(self, config): from torch import optim from torch.distributed.fsdp import MixedPrecision - from verl.utils.model import print_model_size + from verl.utils.model import load_valuehead_model, print_model_size from verl.utils.torch_dtypes import PrecisionType use_shm = config.model.get("use_shm", False) @@ -827,7 +991,11 @@ def _build_critic_model_optimizer(self, config): self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) - from omegaconf import OmegaConf + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) override_config_kwargs = { @@ -842,25 +1010,33 @@ def _build_critic_model_optimizer(self, config): torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) - from transformers import AutoConfig, AutoModelForTokenClassification + from transformers import AutoConfig - critic_model_config = AutoConfig.from_pretrained(local_path, attn_implementation="flash_attention_2", trust_remote_code=config.model.get("trust_remote_code", False)) + critic_model_config = AutoConfig.from_pretrained( + local_path, + attn_implementation="flash_attention_2", + trust_remote_code=config.model.get("trust_remote_code", False), + ) critic_model_config.num_labels = 1 # patch for kimi-vl if getattr(critic_model_config, "model_type", None) == "kimi_vl": critic_model_config.text_config.topk_method = "greedy" - init_context = get_init_weight_context_manager(use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") critic_model_config.classifier_dropout = 0.0 critic_model_config.hidden_dropout = "0" - critic_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=critic_model_config, - trust_remote_code=config.model.get("trust_remote_code", False), + critic_model_config.summary_dropout_prob = 0.0 + + critic_module = load_valuehead_model( + local_path, + torch_dtype, + critic_model_config, + config.model.get("trust_remote_code", False), ) use_remove_padding = config.model.get("use_remove_padding", False) @@ -908,7 +1084,11 @@ def _build_critic_model_optimizer(self, config): mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self.config.model.get("lora_rank", 0) > 0) + auto_wrap_policy = get_fsdp_wrap_policy( + module=critic_module, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) log_gpu_memory_usage("Before critic FSDP", logger=None) @@ -922,17 +1102,19 @@ def _build_critic_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=get_torch_device().current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, - forward_prefetch=False, + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, device_mesh=self.device_mesh, cpu_offload=None, ) elif config.strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True) + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) offload_policy = None if fsdp_config.offload_policy: self._is_offload_param = False @@ -977,9 +1159,13 @@ def _build_critic_model_optimizer(self, config): from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup if warmup_style == "constant": - critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps) + critic_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps + ) elif warmup_style == "cosine": - critic_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps) + critic_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps + ) else: raise NotImplementedError(f"Warmup style {warmup_style} is not supported") @@ -992,7 +1178,9 @@ def init_model(self): from verl.workers.critic import DataParallelPPOCritic - self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(self.config) + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( + self.config + ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) @@ -1001,7 +1189,9 @@ def init_model(self): offload_fsdp_optimizer(optimizer=self.critic_optimizer) log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) - self.critic = DataParallelPPOCritic(config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer) + self.critic = DataParallelPPOCritic( + config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer + ) self.flops_counter = FlopsCounter(self.critic_model_config) self.checkpoint_manager = FSDPCheckpointManager( @@ -1009,13 +1199,14 @@ def init_model(self): optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_contents=self.config.checkpoint.contents, + checkpoint_config=self.config.checkpoint, ) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="cyan") def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1036,13 +1227,14 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="pink") def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) # perform forward computation with self.ulysses_sharding_manager: @@ -1056,9 +1248,9 @@ def update_critic(self, data: DataProto): estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size - self.critic_lr_scheduler.step() lr = self.critic_lr_scheduler.get_last_lr()[0] metrics["critic/lr"] = lr + self.critic_lr_scheduler.step() output = DataProto(batch=None, meta_info={"metrics": metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) @@ -1078,7 +1270,9 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) - self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: @@ -1091,7 +1285,9 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) - self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load) + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) torch.distributed.barrier() if self._is_offload_param: @@ -1102,17 +1298,23 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True) # TODO(sgm): we may need to extract it to dp_reward_model.py -class RewardModelWorker(Worker): +class RewardModelWorker(Worker, DistProfilerExtension): """ Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. """ def __init__(self, config): - super().__init__() + Worker.__init__(self) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + ) + import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") + torch.distributed.init_process_group( + backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) + ) self.config = config # build device mesh for Ulysses Sequence Parallel @@ -1126,7 +1328,9 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]) + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -1151,7 +1355,9 @@ def _build_model(self, config): else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm) - self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) trust_remote_code = config.model.get("trust_remote_code", False) @@ -1159,7 +1365,9 @@ def _build_model(self, config): model_config.num_labels = 1 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -1191,11 +1399,11 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=get_torch_device().current_device(), + device_id=get_device_id(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=False, + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, device_mesh=self.device_mesh, ) elif config.strategy == "fsdp2": @@ -1223,40 +1431,66 @@ def _forward_micro_batch(self, micro_batch): if is_cuda_available: from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + from transformers.integrations.npu_flash_attention import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) - from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] position_ids = micro_batch["position_ids"] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False) # prevent model thinks we are generating + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outpus_and_unpad(reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) + reward_rmpad = gather_outputs_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) else: - output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -1271,6 +1505,8 @@ def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): attention_mask = data.batch["attention_mask"] position_ids = data.batch["position_ids"] response_length = data.batch["responses"].shape[-1] + if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + position_ids = position_ids[:, 0, :] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores @@ -1318,7 +1554,9 @@ def _switch_chat_template(self, data: DataProto): chat.append({"role": "assistant", "content": response}) - prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False) + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) if self.rank == 0 and i == 0: # for debugging purpose print(f"Switch template. chat: {prompt_with_chat_template}") @@ -1351,13 +1589,14 @@ def _switch_chat_template(self, data: DataProto): return DataProto.from_dict(rm_inputs) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + @DistProfiler.annotate(color="brown") def compute_rm_score(self, data: DataProto): import itertools from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches # Support all hardwares - data = data.to(get_torch_device().current_device()) + data = data.to(get_device_id()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: @@ -1372,7 +1611,7 @@ def compute_rm_score(self, data: DataProto): rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(get_torch_device().current_device()) + rm_data.batch = rm_data.batch.to(get_device_id()) # perform forward computation with self.ulysses_sharding_manager: @@ -1432,17 +1671,39 @@ def _build_rollout(self, trust_remote_code=False): def generate_sequences(self, prompts: DataProto): raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") + # ============================ vLLM related ============================ + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def execute_method(self, method: Union[str, bytes], *args, **kwargs): + def execute_method(self, method: str | bytes, *args, **kwargs): """Called by ExternalRayDistributedExecutor collective_rpc.""" - if self.vllm_tp_rank == 0 and method != "execute_model": - print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}") return self.rollout.execute_method(method, *args, **kwargs) @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def resume(self): - return self.rollout.resume() + def get_zeromq_address(self): + return self.rollout.get_zeromq_address() + + # ============================ SGLang related ============================ + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def chat_completion(self, json_request): + ret = await self.rollout.chat_completion(json_request) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def wake_up(self): + if self.config.rollout.free_cache_engine: + await self.rollout.wake_up() + # return something to block the caller + return True @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def offload(self): - return self.rollout.offload() + async def sleep(self): + if self.config.rollout.free_cache_engine: + await self.rollout.sleep() + # return something to block the caller + return True diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index d079e8802..de7267dc9 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -15,23 +15,26 @@ The main entry point to run the PPO algorithm """ +import datetime import logging import os import time -import warnings +from typing import Any +import psutil import torch import torch.distributed from codetiming import Timer from megatron.core import parallel_state as mpu -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf, open_dict from verl import DataProto from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.megatron.worker import MegatronWorker from verl.utils import hf_tokenizer from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.megatron_utils import ( @@ -40,8 +43,16 @@ offload_megatron_model_to_cpu, offload_megatron_optimizer, ) -from verl.utils.model import load_mcore_dist_weights, load_megatron_gptmodel_weights -from verl.utils.torch_functional import broadcast_dict_tensor +from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + GPUMemoryLogger, + log_gpu_memory_usage, + simple_timer, +) +from verl.utils.profiler.performance import reduce_timing +from verl.utils.torch_functional import broadcast_dict_tensor # NOTE: added by Reasoning360 from verl.workers.actor.megatron_actor import MegatronPPOActor from verl.workers.critic.megatron_critic import MegatronPPOCritic from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel @@ -59,7 +70,7 @@ def set_random_seed(seed): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) - if torch.cuda.device_count() > 0: + if get_torch_device().device_count() > 0: from megatron.core import tensor_parallel tensor_parallel.model_parallel_cuda_manual_seed(seed) @@ -73,33 +84,37 @@ def megatron_pp_dummy_output(data: DataProto): from verl.single_controller.base.decorator import _make_dummy_data_proto if ( - mpu.get_pipeline_model_parallel_rank() != mpu.get_pipeline_model_parallel_world_size() - 1 # not the last stage - or mpu.get_tensor_model_parallel_rank() != 0 # not the first tensor parallel rank + mpu.get_pipeline_model_parallel_rank() != mpu.get_pipeline_model_parallel_world_size() - 1 # not the last stage + or mpu.get_tensor_model_parallel_rank() != 0 # not the first tensor parallel rank ): return _make_dummy_data_proto(data) return data -class ActorRolloutRefWorker(MegatronWorker): +class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy or a hybrid engine based on the config.rollout """ - def __init__(self, config: DictConfig, role: str): - super().__init__() + def __init__(self, config: DictConfig, role: str, **kwargs): + MegatronWorker.__init__(self) self.config = config # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. # Therefore, we only require one distribute initialization. - # To utilize different parallel startegy in different models: + # To utilize different parallel strategy in different models: # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group(backend="nccl") - torch.cuda.set_device(rank) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) if self.config.actor.megatron.sequence_parallel: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -124,6 +139,9 @@ def __init__(self, config: DictConfig, role: str): self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] + profiler_config = omega_conf_to_dataclass(config.get("profiler")) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + # TODO(sgm): Currently, we only support reference model param offload # will support other offload later self._is_offload_param = False @@ -148,86 +166,133 @@ def __init__(self, config: DictConfig, role: str): self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size else: - assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time." + assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, ( + "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and " + "`log_prob_micro_batch_size` should not be None at the same time." + ) self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config): - from megatron.core.models.gpt.gpt_model import ModelType - - from verl.utils.megatron.optimizer import get_megatron_optimizer + from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import get_generation_config, print_model_size - self._init_hf_config_and_tf_config(model_path, model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False)) + self._init_hf_config_and_tf_config( + model_path, + model_path, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.actor.megatron.use_mbridge, + ) self.generation_config = get_generation_config(self.local_path) - def megatron_actor_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model + def make_model(wrap_with_ddp=False): + if self.bridge is not None: + from verl.models.mcore.mbridge import freeze_moe_router - parallel_model = init_mcore_model(self.tf_config, self.hf_config, pre_process, post_process, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, value=False, freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False)) - parallel_model.cuda() - return parallel_model + post_model_creation_callbacks = [] + if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + post_model_creation_callbacks.append(freeze_moe_router) + return self.bridge.get_model( + post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=wrap_with_ddp + ) + else: + + def megatron_actor_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + self.tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + value=False, + freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), + ) + parallel_model.to(get_device_name()) + return parallel_model + + override_ddp_config = OmegaConf.to_container( + self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True + ) + return get_model( + megatron_actor_model_provider, + wrap_with_ddp=wrap_with_ddp, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + override_ddp_config=override_ddp_config, + ) - # Step 3: initialize the megatron model if self._is_actor and self._is_rollout: - actor_module = get_model( - megatron_actor_model_provider, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - ) + actor_module = make_model(wrap_with_ddp=True) print(f"actor_module: {len(actor_module)}") if self.config.actor.load_weight: if self.config.actor.megatron.use_dist_checkpointing: - load_mcore_dist_weights(actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False) + load_mcore_dist_weights( + actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False + ) else: - load_megatron_gptmodel_weights(self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False) + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + self.bridge.load_weights(actor_module, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False + ) if self.rank == 0: print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) elif self._is_ref: print(f"self.config.ref.load_weight: {self.config.ref.load_weight}") - ref_module = get_model( - model_provider_func=megatron_actor_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer, - ) - # ref_module = nn.ModuleList(ref_module) - + ref_module = make_model(wrap_with_ddp=False) if self.config.ref.load_weight: # should align with the actor: assert self.config.actor.load_weight == self.config.ref.load_weight print("load ref weight start") if self.config.ref.megatron.use_dist_checkpointing: - load_mcore_dist_weights(ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False) + load_mcore_dist_weights( + ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False + ) else: - load_megatron_gptmodel_weights(self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False) + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + self.bridge.load_weights(ref_module, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False + ) log_gpu_memory_usage("After ref module init", logger=logger) return ref_module, self.hf_config # TODO: add more optimizer args into config if self._is_actor: - optim_config = init_megatron_optim_config(optim_config) - actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config) + optim_config_megatron = init_megatron_optim_config(optim_config) + actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron) + actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler( + optimizer=actor_optimizer, config=optim_config + ) else: optim_config = None actor_optimizer = None + actor_optimizer_scheduler = None log_gpu_memory_usage("After actor optimizer init", logger=logger) - return actor_module, actor_optimizer, self.hf_config, optim_config + return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh layer_name_mapping = { "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.weight", + "gate_proj_layer_name": "linear_fc1.", } if self.config.rollout.name == "vllm": from torch.distributed.device_mesh import init_device_mesh - from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.rollout.vllm_rollout import vLLMRollout from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, @@ -235,27 +300,26 @@ def _build_rollout(self, trust_remote_code=False): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + ) log_gpu_memory_usage("Before building vllm rollout", logger=None) local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) - if vllm_mode == "customized": - rollout = vLLMRollout( - actor_module=self.actor_module, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - ) - elif vllm_mode == "spmd": - rollout = vLLMRollout( - model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code, - ) + from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout + + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) log_gpu_memory_usage("After building vllm rollout", logger=logger) # perform weight resharding between actor and rollout @@ -266,39 +330,43 @@ def _build_rollout(self, trust_remote_code=False): inference_engine=rollout.inference_engine, model_config=self.actor_model_config, transformer_config=self.tf_config, + rollout_config=self.config.rollout, layer_name_mapping=layer_name_mapping, actor_module=self.actor.actor_module, weight_converter=weight_converter, + device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, + bridge=self.bridge, ) log_gpu_memory_usage("After building sharding manager", logger=logger) - elif self.config.rollout.name in ["sglang", "sglang_async"]: - if self.config.rollout.name == "sglang_async": - warnings.warn( - "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.", - DeprecationWarning, - stacklevel=2, - ) + elif self.config.rollout.name == "sglang": from verl.workers.rollout.sglang_rollout import SGLangRollout - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. - # However, due to verl's setting, the main process of ray can not find any CUDA device, which would potentially lead to: - # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's + # model_runner would check CUDA device capability. + # However, due to verl's setting, the main process of ray can not find any CUDA device, which would + # potentially lead to: "RuntimeError: No CUDA GPUs are available". + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it + # here use the abs path. # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")) + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh( + "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp") + ) local_path = copy_to_local(self.config.model.path) log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) rollout = SGLangRollout( actor_module=local_path, config=self.config.rollout, - tokenizer=self.tokenizer, + processing_class=self.processor if self.processor is not None else self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, device_mesh=rollout_device_mesh, @@ -312,15 +380,18 @@ def _build_rollout(self, trust_remote_code=False): actor_module=self.actor.actor_module, inference_engine=rollout._engine, model_config=self.actor_model_config, + rollout_config=self.config.rollout, transformer_config=self.tf_config, layer_name_mapping=layer_name_mapping, weight_converter=weight_converter, + bridge=self.bridge, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) log_gpu_memory_usage("After building sharding manager", logger=logger) else: raise NotImplementedError("Only vllmRollout is supported with Megatron now") - + print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}") return rollout, sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -331,15 +402,17 @@ def init_model(self): importlib.import_module(self.config.model.external_lib) - from omegaconf import OmegaConf - from verl.utils.torch_dtypes import PrecisionType override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) if self._is_actor: - override_transformer_config = OmegaConf.to_container(self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True) + override_transformer_config = OmegaConf.to_container( + self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + ) elif self._is_ref: - override_transformer_config = OmegaConf.to_container(self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True) + override_transformer_config = OmegaConf.to_container( + self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + ) else: override_transformer_config = None self.param_dtype = torch.bfloat16 @@ -348,7 +421,13 @@ def init_model(self): if self._is_actor or self._is_rollout: # we need the model for actor and rollout optim_config = self.config.actor.optim if self._is_actor else None - self.actor_module, self.actor_optimizer, self.actor_model_config, self.actor_optim_config = self._build_model_optimizer( + ( + self.actor_module, + self.actor_optimizer, + self.actor_optimizer_scheduler, + self.actor_model_config, + self.actor_optim_config, + ) = self._build_model_optimizer( model_path=self.config.model.path, optim_config=optim_config, override_model_config=override_model_config, @@ -362,6 +441,10 @@ def init_model(self): log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) if self._is_actor: + OmegaConf.set_struct(self.config.actor, True) + with open_dict(self.config.actor): + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + self.config.actor.use_fused_kernels = use_fused_kernels self.actor = MegatronPPOActor( config=self.config.actor, model_config=self.actor_model_config, @@ -373,7 +456,11 @@ def init_model(self): log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) if self._is_rollout: - self.rollout, self.sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + self.rollout, self.sharding_manager = self._build_rollout( + trust_remote_code=self.config.model.get("trust_remote_code", False) + ) + # used for sleep/wake_up + self.rollout.sharding_manager = self.sharding_manager log_gpu_memory_usage("After rollout init", logger=logger) if self._is_ref: @@ -400,23 +487,30 @@ def init_model(self): self.flops_counter = FlopsCounter(self.actor_model_config) self.checkpoint_mananager = MegatronCheckpointManager( config=self.config, + checkpoint_config=self.config.actor.checkpoint, model_config=self.actor_model_config, + transformer_config=self.tf_config, role="actor", model=self.actor_module, arch=self.architectures[0], hf_config=self.hf_config, param_dtype=self.param_dtype, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - tokenizer=self.tokenizer, + processing_class=self.processor if self.processor is not None else self.tokenizer, optimizer=self.actor_optimizer, + optimizer_scheduler=self.actor_optimizer_scheduler, use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - checkpoint_contents=self.config.actor.checkpoint.contents, + use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler, + bridge=self.bridge, + use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing, ) - torch.cuda.empty_cache() + get_torch_device().empty_cache() log_gpu_memory_usage("After init_model finish", logger=logger) + # Modified by Reasoning360. @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) @GPUMemoryLogger(role="update_actor", logger=logger) + @DistProfiler.annotate(color="red") def update_actor(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -425,9 +519,12 @@ def update_actor(self, data: DataProto): if self._is_offload_optimizer: load_megatron_optimizer(self.actor_optimizer) log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger) - data.batch = data.batch.cuda() + data.batch = data.batch.to(get_device_name()) - broadcast_dict_tensor(data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + # NOTE: added by Reasoning360. + broadcast_dict_tensor( + data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -438,6 +535,13 @@ def update_actor(self, data: DataProto): global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + from verl.utils.megatron.optimizer import get_megatron_last_lr + + metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer) + self.actor_optimizer_scheduler.step(1) # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) @@ -450,51 +554,50 @@ def update_actor(self, data: DataProto): offload_megatron_optimizer(self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) - torch.cuda.empty_cache() + get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @GPUMemoryLogger(role="generate_sequences", logger=logger) + @DistProfiler.annotate(color="red") def generate_sequences(self, prompts: DataProto): assert self._is_rollout - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - log_gpu_memory_usage("After load actor params during generate_sequences", logger=logger) - prompts.batch = prompts.batch.cuda() + prompts.batch = prompts.batch.to(get_device_name()) meta_info = { - "eos_token_id": self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id, + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, } prompts.meta_info.update(meta_info) if self._is_offload_optimizer: offload_megatron_optimizer(self.actor_optimizer) + timing_generate = {} with self.sharding_manager: - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) log_gpu_memory_usage("After entering sharding manager", logger=logger) - - # (zhangchi.usc1992) wake up kv cache here. Currently only support vllm. - # Will support sglang once separate wakeup of model weights and kv cache is supported - # This API should be exposed by the rollout. Will rewrite this part when we refactor after v0.4 release. - # Currently, we hack here to support running large models (QWen3-236b and DeepSeek-671b) - if self.config.rollout.name == "vllm": - import inspect - - if "tags" in inspect.signature(self.rollout.inference_engine.wake_up).parameters: - self.rollout.inference_engine.wake_up(tags=["kv_cache"]) - prompts = self.sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) output = self.sharding_manager.postprocess_data(output) + log_gpu_memory_usage("After rollout generation", logger=logger) + timing_generate.update(self.sharding_manager.timing) + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate = reduce_timing(timing_generate) + output.meta_info["timing"] = timing_generate output = output.to("cpu") # clear kv cache - torch.cuda.empty_cache() + get_torch_device().empty_cache() return output + # Modified by Reasoning360. @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) + @DistProfiler.annotate(color="olive") def compute_ref_log_prob(self, data: DataProto): assert self._is_ref if self._ref_is_offload_param: @@ -505,9 +608,11 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(torch.cuda.current_device()) - - broadcast_dict_tensor(data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + data = data.to(get_device_id()) + # NOTE: added by Reasoning360. + broadcast_dict_tensor( + data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) # NOTE: this function internally broadcasts the last stage's input and output to all ranks. output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) @@ -516,11 +621,14 @@ def compute_ref_log_prob(self, data: DataProto): if self._ref_is_offload_param: offload_megatron_model_to_cpu(self.ref_module) log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) - torch.cuda.empty_cache() + get_torch_device().empty_cache() + # NOTE: added by Reasoning360. return megatron_pp_dummy_output(output) + # Modified by Reasoning360. @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) @GPUMemoryLogger(role="compute_log_prob", logger=logger) + @DistProfiler.annotate(color="blue") def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -531,26 +639,35 @@ def compute_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) - broadcast_dict_tensor(data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + # NOTE: added by Reasoning360. + broadcast_dict_tensor( + data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) - # NOTE: this function internally broadcasts the last stage's input and output to all ranks. output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) - output = DataProto.from_dict(tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}) + output = DataProto.from_dict( + tensors={"old_log_probs": output, "entropys": entropys}, + meta_info={"temperature": self.config.rollout.temperature}, + ) output = output.to("cpu") # clear kv cache if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger) - torch.cuda.empty_cache() + get_torch_device().empty_cache() + + # NOTE: modified by Reasoning360. return megatron_pp_dummy_output(output) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) - self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) if self._is_offload_optimizer: @@ -564,27 +681,95 @@ def load_pretrained_model(self, checkpoint_path, del_local_after_load=True): def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) - self.checkpoint_mananager.save_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) torch.distributed.barrier() if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) -class CriticWorker(MegatronWorker): +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def _build_rollout(self, trust_remote_code=False): + rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) + + # NOTE: rollout is not actually initialized here, it's deferred + # to be initialized by AsyncvLLMServer. + + self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size + self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size + self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size + + # used for sleep/wake_up + rollout.sharding_manager = rollout_sharding_manager + + return rollout, rollout_sharding_manager + + # ============================ vLLM related ============================ + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def execute_method(self, method: str | bytes, *args, **kwargs): + """Called by ExternalRayDistributedExecutor collective_rpc.""" + if self.vllm_tp_rank == 0 and method != "execute_model": + print( + f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: " + f"{method if isinstance(method, str) else 'Callable'}" + ) + return self.rollout.execute_method(method, *args, **kwargs) + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def get_zeromq_address(self): + return self.rollout.get_zeromq_address() + + # ============================ SGLang related ============================ + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def chat_completion(self, json_request): + ret = await self.rollout.chat_completion(json_request) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def wake_up(self): + if self.config.rollout.free_cache_engine: + await self.rollout.wake_up() + # return something to block the caller + return True + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def sleep(self): + if self.config.rollout.free_cache_engine: + await self.rollout.sleep() + # return something to block the caller + return True + + +class CriticWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config): - super().__init__() + MegatronWorker.__init__(self) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + ) self.config = config # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. # Therefore, we only require one distribute initialization. - # To utilize different parallel startegy in different models: + # To utilize different parallel strategy in different models: # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group(backend="nccl") - torch.cuda.set_device(rank) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) if self.config.megatron.sequence_parallel: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -615,29 +800,62 @@ def __init__(self, config): # TODO(sgm): support critic model offload - def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config): + def _build_critic_model_optimizer( + self, model_path, optim_config, override_model_config, override_transformer_config + ): from megatron.core.models.gpt.gpt_model import ModelType - from verl.utils.megatron.optimizer import get_megatron_optimizer + from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import print_model_size - self._init_hf_config_and_tf_config(model_path, self.config.model.tokenizer_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False)) + self._init_hf_config_and_tf_config( + model_path, + self.config.model.tokenizer_path, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.megatron.use_mbridge, + ) - def megatron_critic_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model + if self.bridge is not None: + from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - parallel_model = init_mcore_model(self.tf_config, self.hf_config, pre_process, post_process, share_embeddings_and_output_weights=False, value=True, freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False)) - parallel_model.cuda() - return parallel_model + post_model_creation_callbacks = [make_value_model] + if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + post_model_creation_callbacks.append(freeze_moe_router) + critic_module = self.bridge.get_model( + post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True + ) + else: - # Step 3: initialize the megatron model - critic_module = get_model( - model_provider_func=megatron_critic_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - ) + def megatron_critic_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + self.tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + value=True, + freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), + ) + parallel_model.to(get_device_name()) + return parallel_model + + override_ddp_config = OmegaConf.to_container( + self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True + ) + # Step 3: initialize the megatron model + critic_module = get_model( + model_provider_func=megatron_critic_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + override_ddp_config=override_ddp_config, + ) # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). # but here, we do not use pp (vpp) yet. For simplicity, we remove the list # critic_module = nn.ModuleList(critic_module) @@ -645,9 +863,17 @@ def megatron_critic_model_provider(pre_process, post_process): if self.config.load_weight: t0 = time.time() if self.config.megatron.use_dist_checkpointing: - load_mcore_dist_weights(critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True) + load_mcore_dist_weights( + critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True + ) else: - load_megatron_gptmodel_weights(self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True) + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + self.bridge.load_weights(critic_module, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True + ) t1 = time.time() if torch.distributed.get_rank() == 0: print(f"critic load_weight time: {t1 - t0}") @@ -655,15 +881,17 @@ def megatron_critic_model_provider(pre_process, post_process): print_model_size(critic_module[0]) # TODO: add more optimizer args into config - optim_config = init_megatron_optim_config(optim_config) - critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config) - torch.cuda.empty_cache() - return critic_module, critic_optimizer, self.hf_config, optim_config + optim_config_megatron = init_megatron_optim_config(optim_config) + critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) + critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler( + optimizer=critic_optimizer, config=optim_config + ) + get_torch_device().empty_cache() + return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # create critic - from omegaconf import OmegaConf from verl.utils.torch_dtypes import PrecisionType @@ -673,10 +901,18 @@ def init_model(self): importlib.import_module(self.config.model.external_lib) override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_transformer_config = OmegaConf.to_container(self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True) + override_transformer_config = OmegaConf.to_container( + self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + ) self.param_dtype = torch.bfloat16 self.dtype = PrecisionType.to_dtype(self.param_dtype) - self.critic_module, self.critic_optimizer, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer( + ( + self.critic_module, + self.critic_optimizer, + self.critic_optimizer_scheduler, + self.critic_model_config, + critic_optimizer_config, + ) = self._build_critic_model_optimizer( model_path=self.config.model.path, optim_config=self.config.optim, override_model_config=override_model_config, @@ -699,26 +935,32 @@ def init_model(self): self.flops_counter = FlopsCounter(self.critic_model_config) self.checkpoint_mananager = MegatronCheckpointManager( config=self.config, + checkpoint_config=self.config.checkpoint, model_config=self.critic_model_config, + transformer_config=self.tf_config, role="critic", model=self.critic_module, arch=self.architectures[0], hf_config=self.hf_config, param_dtype=self.param_dtype, share_embeddings_and_output_weights=False, - tokenizer=self.tokenizer, + processing_class=self.processor if self.processor is not None else self.tokenizer, optimizer=self.critic_optimizer, + optimizer_scheduler=self.critic_optimizer_scheduler, use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - checkpoint_contents=self.config.checkpoint.contents, + use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler, + bridge=self.bridge, + use_dist_checkpointing=self.config.megatron.use_dist_checkpointing, ) @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + @DistProfiler.annotate(color="cyan") def compute_values(self, data: DataProto): micro_batch_size = self.config.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) values = self.critic.compute_values(data=data) @@ -729,8 +971,9 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + @DistProfiler.annotate(color="pink") def update_critic(self, data: DataProto): - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) @@ -744,6 +987,11 @@ def update_critic(self, data: DataProto): global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + from verl.utils.megatron.optimizer import get_megatron_last_lr + + metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer) + self.critic_optimizer_scheduler.step(1) + output = DataProto(batch=None, meta_info={"metrics": metrics}) if self._is_offload_param: @@ -757,7 +1005,9 @@ def update_critic(self, data: DataProto): def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) - self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) if self._is_offload_param: offload_megatron_model_to_cpu(self.critic_module) if self._is_offload_optimizer: @@ -767,30 +1017,39 @@ def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load= def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None): if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) - self.checkpoint_mananager.save_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep + ) if self._is_offload_param: offload_megatron_model_to_cpu(self.critic_module) -class RewardModelWorker(MegatronWorker): +class RewardModelWorker(MegatronWorker, DistProfilerExtension): """ Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification. """ def __init__(self, config): - super().__init__() + MegatronWorker.__init__(self) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + ) self.config = config # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. # Therefore, we only require one distribute initialization. - # To utilize different parallel startegy in different models: + # To utilize different parallel strategy in different models: # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 if not torch.distributed.is_initialized(): rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group(backend="nccl") - torch.cuda.set_device(rank) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) if self.config.megatron.sequence_parallel: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -818,47 +1077,70 @@ def _build_rm_model(self, model_path, tokenizer, override_model_config, override from verl.utils.megatron_utils import get_model - self._init_hf_config_and_tf_config(model_path, tokenizer, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False)) + self._init_hf_config_and_tf_config( + model_path, + tokenizer, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.megatron.use_mbridge, + ) + if self.bridge is not None: + from verl.models.mcore.mbridge import freeze_moe_router, make_value_model + + post_model_creation_callbacks = [make_value_model] + if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + post_model_creation_callbacks.append(freeze_moe_router) + reward_model = self.bridge.get_model( + post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=False + ) + else: - def megatron_rm_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model + def megatron_rm_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True, + parallel_model = init_mcore_model( + self.tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + value=True, + ) + parallel_model.to(get_device_name()) + return parallel_model + + # Step 3: initialize the megatron model + reward_model = get_model( + model_provider_func=megatron_rm_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, ) - parallel_model.cuda() - return parallel_model - - # Step 3: initialize the megatron model - reward_model = get_model( - model_provider_func=megatron_rm_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - ) - # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). - # but here, we do not use pp (vpp) yet. For simplicity, we remove the list - # reward_model = nn.ModuleList(reward_model) + # note that here reward_model will be a list to be compatible with the construction of interleaved pp (vpp) + # but here, we do not use pp (vpp) yet. For simplicity, we remove the list + # reward_model = nn.ModuleList(reward_model) if self.config.load_weight: if self.config.megatron.use_dist_checkpointing: load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True) else: - load_megatron_gptmodel_weights(self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True) + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + self.bridge.load_weights(reward_model, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True + ) # TODO: add more optimizer args into config - torch.cuda.empty_cache() + get_torch_device().empty_cache() return reward_model, self.hf_config @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # create critic - from omegaconf import OmegaConf from verl.utils.torch_dtypes import PrecisionType @@ -868,7 +1150,9 @@ def init_model(self): importlib.import_module(self.config.model.external_lib) override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_transformer_config = OmegaConf.to_container(self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True) + override_transformer_config = OmegaConf.to_container( + self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + ) use_shm = self.config.model.get("use_shm", False) sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm) @@ -877,7 +1161,9 @@ def init_model(self): rm_tokenizer = None if rm_tokenizer_path is not None: rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm) - rm_tokenizer = hf_tokenizer(rm_tokenizer_local_path, trust_remote_code=self.config.model.get("trust_remote_code", False)) + rm_tokenizer = hf_tokenizer( + rm_tokenizer_local_path, trust_remote_code=self.config.model.get("trust_remote_code", False) + ) self.param_dtype = torch.bfloat16 self.dtype = PrecisionType.to_dtype(self.param_dtype) @@ -903,11 +1189,12 @@ def init_model(self): # TODO: reward model use itself tokenizer instead of sft tokenizer # the input_ids, responses, attention_mask and position_ids may be different! @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) + @DistProfiler.annotate(color="brown") def compute_rm_score(self, data: DataProto): data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - data = data.to(torch.cuda.current_device()) + data = data.to(get_device_id()) output = self.rm.compute_reward(data) output = output.to("cpu") return output diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py index 79474d642..173cf1bb8 100644 --- a/verl/workers/reward_manager/__init__.py +++ b/verl/workers/reward_manager/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .registry import get_reward_manager_cls, register # noqa: I001 from .batch import BatchRewardManager from .dapo import DAPORewardManager from .naive import NaiveRewardManager @@ -19,10 +20,19 @@ # Added by Reasoning360 from .naive_parallel import NaiveParallelRewardManager -from .async_dapo import AsyncDAPORewardManager +from .async_mp import AsyncMultiProcessRewardManager from .llm_judge import LLMJudgeRewardManager -__all__ = ["BatchRewardManager", "DAPORewardManager", "NaiveRewardManager", "PrimeRewardManager", - # Added by Reasoning360 - "NaiveParallelRewardManager", "AsyncDAPORewardManager", "LLMJudgeRewardManager", - ] +# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies +__all__ = [ + "BatchRewardManager", + "DAPORewardManager", + "NaiveRewardManager", + "PrimeRewardManager", + "register", + "get_reward_manager_cls", + # Added by Reasoning360 + "NaiveParallelRewardManager", + "AsyncMultiProcessRewardManager", + "LLMJudgeRewardManager", +] diff --git a/verl/workers/reward_manager/async_dapo.py b/verl/workers/reward_manager/async_mp.py similarity index 79% rename from verl/workers/reward_manager/async_dapo.py rename to verl/workers/reward_manager/async_mp.py index a16057da4..09b913c74 100644 --- a/verl/workers/reward_manager/async_dapo.py +++ b/verl/workers/reward_manager/async_mp.py @@ -13,25 +13,34 @@ # limitations under the License. import asyncio -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from functools import partial from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor +from functools import partial + import numpy as np +import torch from verl import DataProto from verl.utils.reward_score import _default_compute_score -import torch +from verl.workers.reward_manager import register -async def single_compute_score(compute_score_fn, data_source, solution_str, ground_truth, extra_info, executor, timeout=300.): +async def single_compute_score( + compute_score_fn, data_source, solution_str, ground_truth, extra_info, executor, timeout=300.0 +): loop = asyncio.get_running_loop() try: tasks = [ asyncio.wait_for( loop.run_in_executor( executor, - partial(compute_score_fn, data_source=data_source, solution_str=solution_str, - ground_truth=ground_truth, extra_info=extra_info) + partial( + compute_score_fn, + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + ), ), timeout=timeout, ) @@ -45,22 +54,30 @@ async def single_compute_score(compute_score_fn, data_source, solution_str, grou return None -async def parallel_compute_score_async(compute_score_fn, data_sources, solutions, ground_truths, - extra_infos, num_processes=64, batch_size=None, shuffle=False): +async def parallel_compute_score_async( + compute_score_fn, + data_sources, + solutions, + ground_truths, + extra_infos, + num_processes=64, + batch_size=None, + shuffle=False, +): # If batch_size is not set, process all items at once if batch_size is None or batch_size <= 0: batch_size = len(data_sources) - + # Create indices for tracking original positions indices = list(range(len(data_sources))) - + # Shuffle data if required if shuffle: # Create a copy of the original indices for restoring order later original_indices = indices.copy() # Create shuffled indices shuffled_indices = np.random.permutation(len(data_sources)) - + # Apply shuffling to all data arrays data_sources = [data_sources[i] for i in shuffled_indices] solutions = [solutions[i] for i in shuffled_indices] @@ -68,45 +85,45 @@ async def parallel_compute_score_async(compute_score_fn, data_sources, solutions extra_infos = [extra_infos[i] for i in shuffled_indices] # Map shuffled positions to original indices indices = [original_indices[i] for i in shuffled_indices] - + results = [None] * len(data_sources) - + with ProcessPoolExecutor(max_workers=num_processes) as executor: # Process data in batches for start_idx in range(0, len(data_sources), batch_size): end_idx = min(start_idx + batch_size, len(data_sources)) - + # Create tasks for current batch tasks_async = [ single_compute_score( - compute_score_fn, - data_sources[i], - solutions[i], - ground_truths[i], - extra_infos[i], - executor, - timeout=300. + compute_score_fn, + data_sources[i], + solutions[i], + ground_truths[i], + extra_infos[i], + executor, + timeout=300.0, ) for i in range(start_idx, end_idx) ] - + # Handle potential exceptions to prevent process starvation try: batch_results = await asyncio.gather(*tasks_async, return_exceptions=False) - + # Store results in their correct positions for i, result in enumerate(batch_results): actual_idx = start_idx + i results[actual_idx] = result - - except Exception as e: + + except Exception: for pid, proc in executor._processes.items(): try: proc.kill() except Exception as kill_err: - print('shut down failed: ' + str(kill_err)) + print("shut down failed: " + str(kill_err)) raise - + # Restore original order if data was shuffled if shuffle: # Create a mapping to restore original order @@ -114,24 +131,26 @@ async def parallel_compute_score_async(compute_score_fn, data_sources, solutions for i, original_idx in enumerate(indices): ordered_results[original_idx] = results[i] results = ordered_results - + return results -class AsyncDAPORewardManager: - """The reward manager. - """ - - def __init__(self, - tokenizer, - num_examine, - compute_score=None, - reward_fn_key='data_source', - max_resp_len=None, - overlong_buffer_cfg=None, - batch_size=2048, - shuffle_batch=True, - **kwargs) -> None: +@register("async_multi_process") +class AsyncMultiProcessRewardManager: + """The reward manager.""" + + def __init__( + self, + tokenizer, + num_examine, + compute_score=None, + reward_fn_key="data_source", + max_resp_len=None, + overlong_buffer_cfg=None, + batch_size=2048, + shuffle_batch=True, + **kwargs, + ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score @@ -142,20 +161,22 @@ def __init__(self, self.shuffle_batch = shuffle_batch if self.overlong_buffer_cfg is not None: - assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + assert self.max_resp_len is not None, ( + f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + ) def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): + if "rm_scores" in data.batch.keys(): if return_dict: - return {"reward_tensor": data.batch['rm_scores']} + return {"reward_tensor": data.batch["rm_scores"]} else: - return data.batch['rm_scores'] + return data.batch["rm_scores"] # print(f"[DEBUG] data.batch['responses'] shape: {data.batch['responses'].shape}") - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) # print(f"[DEBUG] reward_tensor initial shape: {reward_tensor.shape}") # Add this to understand DataProto structure @@ -173,7 +194,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): for i in range(len(data)): data_source = data[i].non_tensor_batch[self.reward_fn_key] data_source_counts[data_source] += 1 - + # print(f"[DEBUG] Data source distribution: {dict(data_source_counts)}") # Check if any data is being filtered @@ -182,9 +203,12 @@ def __call__(self, data: DataProto, return_dict: bool = False): if self.reward_fn_key not in data_item.non_tensor_batch: # print(f"[DEBUG] Warning: Item {i} missing reward_fn_key '{self.reward_fn_key}'") pass - + # Check if ground truth exists - if 'reward_model' not in data_item.non_tensor_batch or 'ground_truth' not in data_item.non_tensor_batch['reward_model']: + if ( + "reward_model" not in data_item.non_tensor_batch + or "ground_truth" not in data_item.non_tensor_batch["reward_model"] + ): # print(f"[DEBUG] Warning: Item {i} missing ground_truth") pass @@ -200,30 +224,30 @@ def __call__(self, data: DataProto, return_dict: bool = False): for i in range(len(data)): data_item = data[i] # DataProtoItem - prompt_ids = data_item.batch['prompts'] + prompt_ids = data_item.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_lengths.append(valid_response_length) valid_response_ids = response_ids[:valid_response_length] # decode prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) prompt_strs.append(prompt_str) - + response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) eos_token = self.tokenizer.eos_token if response_str.endswith(eos_token): - response_str = response_str[:-len(eos_token)] + response_str = response_str[: -len(eos_token)] response_strs.append(response_str) - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get('extra_info', None) + extra_info = data_item.non_tensor_batch.get("extra_info", None) data_sources.append(data_source) solutions.append(response_str) @@ -242,18 +266,18 @@ def __call__(self, data: DataProto, return_dict: bool = False): extra_infos, num_processes=64, batch_size=self.batch_size, - shuffle=self.shuffle_batch + shuffle=self.shuffle_batch, ) ) # print(f"[DEBUG] Parallel score computation completed") - except Exception as e: + except Exception: # print(f"[DEBUG] Error in parallel score computation: {e}") # Fallback to zeros if computation fails results = [None] * len(solutions) # Process results for i, (result, data_source, response_str, ground_truth, valid_response_length) in enumerate( - zip(results, data_sources, response_strs, ground_truths, valid_response_lengths) + zip(results, data_sources, response_strs, ground_truths, valid_response_lengths, strict=False) ): score = 0.0 if result is None: @@ -263,7 +287,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): if not isinstance(result, dict): # Hack to avoid some rewards don't return a dict result = {"score": result, "acc": result} - + score = result["score"] # Store the information including original reward for key, value in result.items(): @@ -298,7 +322,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): for key, value in result.items(): print(f"[{key}]", value) else: - print(f"[score]", score) + print("[score]", score) # print(f"[DEBUG] Final reward_tensor shape: {reward_tensor.shape}") # print(f"[DEBUG] Non-zero elements in reward_tensor: {(reward_tensor != 0).sum().item()}") @@ -310,4 +334,4 @@ def __call__(self, data: DataProto, return_dict: bool = False): "reward_extra_info": reward_extra_info, } else: - return reward_tensor \ No newline at end of file + return reward_tensor diff --git a/verl/workers/reward_manager/batch.py b/verl/workers/reward_manager/batch.py index 570fdd71d..8d1b11228 100644 --- a/verl/workers/reward_manager/batch.py +++ b/verl/workers/reward_manager/batch.py @@ -17,9 +17,22 @@ import torch from verl import DataProto +from verl.workers.reward_manager import register +@register("batch") class BatchRewardManager: + """ + A batch reward manager that computes rewards for a batch of data. + + Args: + tokenizer (Tokenizer): The tokenizer to use for decoding the responses. + num_examine (int): The number of responses to examine. + compute_score (callable): The function to compute the rewards. + reward_fn_key (str): The key to use for the reward function. + reward_kwargs (dict): The keyword arguments to pass to the reward function. + """ + def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs): self.tokenizer = tokenizer self.num_examine = num_examine diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index 399cdf05e..cb8b5cf22 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -18,8 +18,10 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register +@register("dapo") class DAPORewardManager: """The reward manager.""" @@ -40,7 +42,12 @@ def __init__( self.max_resp_len = max_resp_len if self.overlong_buffer_cfg is not None: - assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + assert self.max_resp_len is not None, ( + f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + ) + assert self.max_resp_len >= self.overlong_buffer_cfg.len, ( + "max_resp_len must be larger than overlong_buffer.len" + ) def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" @@ -99,6 +106,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): reward_extra_info[key].append(value) else: score = result + reward_extra_info["acc"].append(score) reward = score diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 59ad618c4..f6f979eef 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -18,16 +18,28 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register +@register("naive") class NaiveRewardManager: """The reward manager.""" def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: - self.tokenizer = tokenizer + """ + Initialize the NaiveRewardManager instance. + + Args: + tokenizer: The tokenizer used to decode token IDs into text. + num_examine: The number of batches of decoded responses to print to the console for debugging purpose. + compute_score: A function to compute the reward score. If None, `default_compute_score` will be used. + reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to + "data_source". + """ + self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or default_compute_score - self.reward_fn_key = reward_fn_key + self.reward_fn_key = reward_fn_key # Store the key for accessing the data source def __call__(self, data: DataProto, return_dict=False): """We will expand this function gradually based on the available datasets""" @@ -63,10 +75,10 @@ def __call__(self, data: DataProto, return_dict=False): response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] - data_source = data_item.non_tensor_batch[self.reward_fn_key] - - extra_info = data_item.non_tensor_batch.get("extra_info", None) + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + num_turns = data_item.non_tensor_batch.get("__num_turns__", None) + extra_info["num_turns"] = num_turns score = self.compute_score( data_source=data_source, diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index d1a68d85f..f2c526b63 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -23,16 +23,14 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): loop = asyncio.get_running_loop() try: # Ensure process_completion is called properly - future = loop.run_in_executor( - executor, - partial(evaluation_func, task, completion, reference, task_extra_info) - ) + future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info)) return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: print(f"[Timeout] Task timeout: {completion}") @@ -42,17 +40,20 @@ async def single_compute_score(evaluation_func, completion, reference, task, tas return None # Default value for failed rows -async def parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64): +async def parallel_compute_score_async( + evaluation_func, completions, references, tasks, extra_info=None, num_processes=64 +): if extra_info is None: extra_info = [None] * len(tasks) scores = [] with ProcessPoolExecutor(max_workers=num_processes) as executor: - # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed. + # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the + # exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed. try: # Create tasks for all rows tasks_async = [ single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=300.0) - for c, r, t, ei in zip(completions, references, tasks, extra_info) + for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True) ] results = await asyncio.gather(*tasks_async, return_exceptions=False) except Exception as e: @@ -74,27 +75,29 @@ async def parallel_compute_score_async(evaluation_func, completions, references, print(f"[Shutdown] {terminated_count} subprocess(es) terminated.") # Process results - for result, completion, reference, task in zip(results, completions, references, tasks): + for result, completion, reference, task in zip(results, completions, references, tasks, strict=True): if isinstance(result, Exception) or result is None: # Handle failed or timed-out tasks scores.append(0.0) - elif isinstance(result, (int, float, bool)): + elif isinstance(result, int | float | bool): scores.append(float(result)) else: scores.append(float(result[0])) return scores + def run_reward_scoring(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(parallel_compute_score_async( - evaluation_func, completions, references, tasks, extra_info, num_processes - )) + return loop.run_until_complete( + parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes) + ) finally: loop.close() +@register("prime") class PrimeRewardManager: """ The Reward Manager used in https://github.com/PRIME-RL/PRIME diff --git a/verl/workers/reward_manager/registry.py b/verl/workers/reward_manager/registry.py new file mode 100644 index 000000000..3fc34efaa --- /dev/null +++ b/verl/workers/reward_manager/registry.py @@ -0,0 +1,51 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["register", "get_reward_manager_cls"] + +REWARD_MANAGER_REGISTRY = {} + + +def register(name): + """Decorator to register a reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + """ + + def decorator(cls): + if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls: + raise ValueError( + f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}" + ) + REWARD_MANAGER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_reward_manager_cls(name): + """Get the reward manager class with a given name. + + Args: + name: `(str)` + The name of the reward manager. + + Returns: + `(type)`: The reward manager class. + """ + if name not in REWARD_MANAGER_REGISTRY: + raise ValueError(f"Unknown reward manager: {name}") + return REWARD_MANAGER_REGISTRY[name] diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py index 3f64e5b24..01b132497 100644 --- a/verl/workers/reward_model/megatron/reward_model.py +++ b/verl/workers/reward_model/megatron/reward_model.py @@ -24,6 +24,7 @@ from tensordict import TensorDict from verl import DataProto +from verl.utils.device import get_device_id, get_device_name, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length @@ -73,7 +74,7 @@ def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: position_ids_for_rm = [] print_decode = True ori_seqlen = ori_seqlen + 128 - for id, mask in zip(input_ids, attention_mask): + for id, mask in zip(input_ids, attention_mask, strict=True): # 1. remove pad for each sequence non_zero_indices = torch.nonzero(mask).view(-1) begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item() @@ -81,17 +82,24 @@ def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: # 2. decode by sft_tokenizer, remove sft system prompts decode_result = self.sft_tokenizer.decode(valid_id) # workaround - decode_with_rm_chat = decode_result.replace("<|user|>\n", "[INST] ").replace("\n<|assistant|>\n", " [/INST]").replace(" \n<|assistant|>\n", " [/INST]") + "" + decode_with_rm_chat = ( + decode_result.replace("<|user|>\n", "[INST] ") + .replace("\n<|assistant|>\n", " [/INST]") + .replace(" \n<|assistant|>\n", " [/INST]") + + "" + ) if print_decode and torch.distributed.get_rank() == 0: # only print first decode result print( - f"device {torch.cuda.current_device()}: sft decode result:\n{decode_result}\n \ - \ndevice {torch.cuda.current_device()}: sft decode result with \ + f"device {get_device_id()}: sft decode result:\n{decode_result}\n \ + \ndevice {get_device_id()}: sft decode result with \ rm chat template:\n{decode_with_rm_chat}\n\n" ) print_decode = False # 3. encode by rm_tokenizer - rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")["input_ids"][0].to(input_ids.device) + rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")["input_ids"][0].to( + input_ids.device + ) # 4. generate attention_mask and position_ids rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device) cur_seqlen = rm_input_ids.shape[-1] @@ -144,7 +152,9 @@ def compute_reward(self, data: DataProto) -> DataProto: response_length = responses.size(1) with torch.no_grad(): - output = self.forward_batch(data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len) + output = self.forward_batch( + data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len + ) if mpu.is_pipeline_last_stage(ignore_virtual=True): logits = torch.cat(output["output"], dim=0) if use_dynamic_bsz: @@ -183,6 +193,8 @@ def compute_reward(self, data: DataProto) -> DataProto: token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen) # assign last valid token reward to ori position + if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + position_ids = position_ids[:, 0, :] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bs,) eos_mask = torch.zeros_like(attention_mask) eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0 @@ -194,7 +206,7 @@ def compute_reward(self, data: DataProto) -> DataProto: self.offload_params_to_cpu() else: # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() batch = TensorDict({"rm_scores": token_level_rewards}, batch_size=input_ids.shape[0]) @@ -210,23 +222,43 @@ def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size # TODO: actually, we just need to control the sampling order. mini_batch = data mini_batch.batch = mini_batch.batch.contiguous() - broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) + broadcast_dict_tensor( + mini_batch.batch, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( + list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) + ).to(torch.int64) + indices = None if use_dynamic_bsz: assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend" + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, + num_batches_divided_by=microbatch_group_size_per_vp_stage, + max_token_len=max_token_len, + ) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " + f"{microbatch_group_size_per_vp_stage} for megatron backend" + ) else: micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) total_seqlen = max_token_len else: - assert micro_batch_size is not None, "micro_batch_size is needed to be passed in when not using dynamic batch size" + assert micro_batch_size is not None, ( + "micro_batch_size is needed to be passed in when not using dynamic batch size" + ) micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -247,6 +279,13 @@ def forward_step(batch_iter, model): forward_fn = get_mcore_forward_fn(self.hf_config) + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0 + ) + output = forward_fn( model, input_ids, @@ -254,6 +293,7 @@ def forward_step(batch_iter, model): position_ids, sequence_parallel=self.tf_config.sequence_parallel, value_model=True, + multi_modal_inputs=multi_modal_inputs, ) return output, loss_func @@ -283,6 +323,11 @@ def forward_step(batch_iter, model): micro_batch_size=1, # in use for pp = 1 forward_only=True, ) + + if self.has_multi_modal_inputs: + data.batch.pop("multi_modal_inputs") + data.batch.pop("multi_modal_inputs_idx") + data.non_tensor_batch.pop("multi_modal_inputs") # loss_reduces contains the stats returned from loss_func losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: @@ -290,16 +335,16 @@ def forward_step(batch_iter, model): return losses_reduced def offload_params_to_cpu(self): - if self.device == "cuda": + if self.device in ["cuda", "npu"]: for reward_model_module in self.reward_model_module: for name, param in reward_model_module.named_parameters(): param.data = param.data.to("cpu", non_blocking=True) self.device = "cpu" - torch.cuda.empty_cache() + get_torch_device().empty_cache() def load_params_to_cuda(self): if self.device == "cpu": for reward_model_module in self.reward_model_module: for name, param in reward_model_module.named_parameters(): - param.data = param.data.to(torch.cuda.current_device(), non_blocking=True) - self.device = "cuda" + param.data = param.data.to(get_device_id(), non_blocking=True) + self.device = get_device_name() diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py index 160cdbaf7..8b09c0b59 100644 --- a/verl/workers/rollout/async_server.py +++ b/verl/workers/rollout/async_server.py @@ -12,31 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -import heapq -import importlib import logging import os import socket import threading from abc import ABC, abstractmethod from contextlib import asynccontextmanager -from typing import Any, Callable, Dict, List, Tuple, Type -from uuid import uuid4 +from typing import Any, Optional -import aiohttp import fastapi import ray import uvicorn -from cachetools import LRUCache from omegaconf import DictConfig -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion from starlette.requests import Request from verl.protocol import DataProto from verl.single_controller.ray.base import RayWorkerGroup -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler logger = logging.getLogger(__file__) @@ -51,7 +43,7 @@ class AsyncServerBase(ABC): """Base class for AsyncServer.""" def __init__(self): - self.address = ray._private.services.get_node_ip_address() + self.address = ray.util.get_node_ip_address() self.port = None self.server_ready = asyncio.Event() asyncio.create_task(self._start_fastapi_server()) @@ -59,7 +51,7 @@ def __init__(self): async def _start_fastapi_server(self): @asynccontextmanager async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") + print(f"FastAPI listen on {self.address}:{self.port}") self.server_ready.set() yield @@ -76,7 +68,7 @@ async def lifespan(app: fastapi.FastAPI): server = uvicorn.Server(config) await server.serve() - async def get_server_address(self) -> Tuple[str, int]: + async def get_server_address(self) -> tuple[str, int]: """Get FastAPI server address.""" await self.server_ready.wait() return f"{self.address}:{self.port}" @@ -89,6 +81,20 @@ async def chat_completion(self, raw_request: Request): """ raise NotImplementedError + @abstractmethod + async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + """Generate response ids given prompt ids. + + Args: + prompt_ids (List[int]): prompt ids + sampling_params (Dict[str, Any]): sampling params + request_id (str): request id + + Returns: + List[int]: response ids + """ + raise NotImplementedError + @abstractmethod async def init_engine(self): """Init async LLM engine.""" @@ -105,130 +111,19 @@ async def sleep(self): raise NotImplementedError -class ChatCompletionScheduler: - def __init__( - self, - config: DictConfig, - model_path: str, - server_addresses: List[str], - max_cache_size: int = 10000, - ): - """ - Args: - config: DictConfig, rollout config. - model_path: str, model path. - server_addresses: List[str], server addresses. - max_cache_size: int, max cache size of request_id to address mapping. - """ - self.config = config - self.model_name = "/".join(model_path.split("/")[-2:]) - local_path = copy_to_local(model_path) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) - - # Least requests load balancing - self.weighted_addresses = [[0, address] for address in server_addresses] - heapq.heapify(self.weighted_addresses) - - # LRU cache to map request_id to address - self.request_id_to_address = LRUCache(maxsize=max_cache_size) - - async def submit_chat_completions( - self, - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], - callback_additional_info: Dict[str, Any], - **chat_complete_request, - ): - """ - Submit a chat completion request to the server with the least number of requests. - - Args: - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function - to handle the response. The callback function should have the following signature: - - ```python - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - ... - ``` - - completions: chat completion response from server. - - info: user provided `callback_additional_info`. - - exception: exception raise from OpenAI client if request failed, otherwise None. - - **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, - please move to seperate thread or process pool to avoid blocking the event loop. - - callback_additional_info: Dict[str, Any], additional info to pass to the callback function. - - **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. - OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create - """ - if "extra_headers" not in chat_complete_request: - chat_complete_request["extra_headers"] = {} - - extra_headers = chat_complete_request["extra_headers"] - request_id = extra_headers.get("x-request-id", None) - if request_id: - if request_id.startswith("chatcmpl-"): - request_id = request_id[len("chatcmpl-") :] - extra_headers["x-request-id"] = request_id - - address = self.request_id_to_address.pop(request_id) - else: - address = self.weighted_addresses[0][1] - self.weighted_addresses[0][0] += 1 - heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) - - # use new request_id to avoid duplicate request_id problem - request_id = uuid4().hex - self.request_id_to_address[request_id] = address - chat_complete_request["extra_headers"]["x-request-id"] = request_id - - completions, exception = None, None - try: - # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. - completions = await self._chat_completions_aiohttp(address, **chat_complete_request) - except Exception as e: - # Let user handle the exception - exception = e - - await callback(completions, callback_additional_info, exception) - - async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: - client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0) - return await client.chat.completions.create(**chat_complete_request) - - async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: - try: - extra_headers = chat_complete_request.pop("extra_headers") - timeout = aiohttp.ClientTimeout(total=None) - session = aiohttp.ClientSession(timeout=timeout) - async with session.post( - url=f"http://{address}/v1/chat/completions", - headers={"Authorization": "Bearer token-abc123", **extra_headers}, - json=chat_complete_request, - ) as resp: - data = await resp.json() - return ChatCompletion(**data) - finally: - await session.close() - - async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: - raise NotImplementedError - - class AsyncLLMServerManager: """AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer.""" - def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, scheduler_kwargs: Dict[str, Any] = None): + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): """Initialize AsyncLLMServerManager. Args: config: DictConfig, actor_rollout_ref config. worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. - scheduler_kwargs: Dict[str, Any], kwargs for chat scheduler. """ - self.config = config + self.full_config = config + self.config = config.actor_rollout_ref self.worker_group = worker_group - self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {} self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size @@ -240,9 +135,14 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, schedule self.async_llm_servers = [None] * self.rollout_dp_size self.server_addresses = [None] * self.rollout_dp_size - server_class = async_server_class( - rollout_backend=self.config.rollout.name, - ) + if self.config.rollout.agent.custom_async_server: + server_class = async_server_class( + rollout_backend=self.config.rollout.name, + rollout_backend_module=self.config.rollout.agent.custom_async_server.path, + rollout_backend_class=self.config.rollout.agent.custom_async_server.name, + ) + else: + server_class = async_server_class(rollout_backend=self.config.rollout.name) # Start all server instances, restart if address already in use. unready_dp_ranks = set(range(self.rollout_dp_size)) @@ -274,6 +174,7 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, schedule # Init user provided chat scheduler in sperate thread. self.chat_scheduler: ChatCompletionScheduler = None + self.chat_scheduler_exception: Exception = None self.chat_scheduler_loop = None self.chat_scheduler_ready = threading.Event() self.chat_scheduler_thread = threading.Thread(target=self._init_chat_scheduler, daemon=True) @@ -284,32 +185,32 @@ def _init_chat_scheduler(self): self.chat_scheduler_loop = asyncio.new_event_loop() asyncio.set_event_loop(self.chat_scheduler_loop) - module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1) - module = importlib.import_module(module_path) - scheduler_cls = getattr(module, class_name) - self.chat_scheduler = scheduler_cls( - config=self.config.rollout, - model_path=self.config.model.path, - server_addresses=self.server_addresses, - **self.scheduler_kwargs, - ) - - self.chat_scheduler_ready.set() + try: + self.chat_scheduler = ChatCompletionScheduler( + config=self.full_config, + server_addresses=self.server_addresses, + ) + except Exception as e: + logger.exception(f"chat_scheduler init error: {e}") + self.chat_scheduler_exception = e + finally: + self.chat_scheduler_ready.set() self.chat_scheduler_loop.run_forever() def wake_up(self): """Wake up all vllm instances.""" - ray.get([server.wake_up.remote() for server in self.async_llm_servers]) + if self.config.rollout.free_cache_engine: + ray.get([server.wake_up.remote() for server in self.async_llm_servers]) def sleep(self): """Sleep all vllm instances.""" - ray.get([server.sleep.remote() for server in self.async_llm_servers]) + if self.config.rollout.free_cache_engine: + ray.get([server.sleep.remote() for server in self.async_llm_servers]) def submit_chat_completions( self, - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], - callback_additional_info: Dict[str, Any], - **chat_complete_request, + messages: list[dict[str, str]], + sampling_params: dict[str, Any], ): """Submit a chat completion request to chat scheduler and wait until it is done. To submit multiple requests in parallel, please use `generate_sequences` instead. @@ -318,10 +219,10 @@ def submit_chat_completions( """ assert self.chat_scheduler is not None, "chat scheduler is not initialized." future = asyncio.run_coroutine_threadsafe( - self.chat_scheduler.submit_chat_completions( - callback=callback, - callback_additional_info=callback_additional_info, - **chat_complete_request, + self.chat_scheduler._submit_chat_completions_semaphore( + messages=messages, + request_id=None, + sampling_params=sampling_params, ), self.chat_scheduler_loop, ) @@ -331,26 +232,44 @@ def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto """Generate multiple sequences in parallel via chat scheduler.""" assert self.chat_scheduler is not None, "chat scheduler is not initialized." - future = asyncio.run_coroutine_threadsafe(self.chat_scheduler.generate_sequences(prompts, **sampling_params), self.chat_scheduler_loop) + future = asyncio.run_coroutine_threadsafe( + self.chat_scheduler.generate_sequences(prompts, **sampling_params), self.chat_scheduler_loop + ) return future.result() -def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]: +def async_server_class( + rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None +) -> type[AsyncServerBase]: """Get async server class. Args: - rollout_backend: str, rollout backend, should be "vllm" or "sglang". + rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang". + rollout_backend_module: Optional[str], import path of the rollout backend. + rollout_backend_class: Optional[str], class name of the rollout backend. Returns: Type[AsyncServerBase]: async server class. """ - if rollout_backend == "vllm": - from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer + if rollout_backend_class is None and rollout_backend_module is None: + # If both are None, use the default backend class + # Do not change the original import behavior + # importlib.import_module and from ... import ... have subtle differences in ray - return AsyncvLLMServer - elif rollout_backend == "sglang": - from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer + if rollout_backend == "vllm": + from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer - return AsyncSglangServer - else: - raise NotImplementedError + return AsyncvLLMServer + elif rollout_backend == "sglang": + from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer + + return AsyncSglangServer + else: + raise NotImplementedError(f"rollout backend {rollout_backend} is not supported") + + if rollout_backend_module is None or rollout_backend_class is None: + raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization") + + from verl.utils.import_utils import load_extern_type + + return load_extern_type(rollout_backend_module, rollout_backend_class) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py new file mode 100644 index 000000000..268c82d02 --- /dev/null +++ b/verl/workers/rollout/chat_scheduler.py @@ -0,0 +1,444 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import heapq +import importlib +import itertools +import json +import logging +import time +from abc import ABC, abstractmethod +from typing import Any +from uuid import uuid4 + +import aiohttp +import numpy as np +import torch +from cachetools import LRUCache +from omegaconf import DictConfig +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion +from tensordict import TensorDict + +from verl.protocol import DataProto +from verl.tools.utils.tool_registry import initialize_tools_from_config +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.import_utils import deprecated + +logger = logging.getLogger(__file__) + + +class CompletionCallback(ABC): + def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): + self.config = config + self.scheduler = scheduler + + # Initialize tools from config file + self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path + tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + self.tools = {tool.name: tool for tool in tool_list} + self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + print(f"Initialized tools: {self.tools}", flush=True) + + local_path = copy_to_local(config.actor_rollout_ref.model.path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + @property + def tool_schemas(self): + """OpenAI JSON tool schemas.""" + return self._tool_schemas + + @property + def extra_body(self) -> dict[str, Any]: + """Extra body pass to OpenAI API.""" + return None + + @abstractmethod + async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): + """Call back function to process completions. + + Args: + messages: List of messages including raw prompt and assistant, tool response generated so far. + completions: Chat completions from OpenAI compatible server. + info: Any other auxiliary information pass across multi-turn. + """ + raise NotImplementedError + + @abstractmethod + def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: + """Post process batch data. + + Args: + batch: Batch input messages from RLHFDataset. + batch_conversations: List of messages including raw prompt, assistant response, tool response. + Note that `len(batch_conversations) == len(batch) * n`, e.g n=2, + batch_conversations=[messages_0_0, messages_0_1, messages_1_0, messages_1_1, ...] + n: How many chat completion choices to generate for each input message. + + Returns: + Batch data, should include ["prompts", "responses", "response_mask", "input_ids", "attention_mask", + "position_ids"]. + """ + raise NotImplementedError + + +class ToolCompletionCallback(CompletionCallback): + def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): + super().__init__(config, scheduler) + + # TODO: add reward manager to calculate reward score once a sample finish + + async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): + message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + if "content" not in message: + message["content"] = "" + messages.append(message) + finish_reason = completions.choices[0].finish_reason + + # STEP 0: check if we reach max turns + if self.max_assistant_turns and len(messages) >= self.max_assistant_turns: + print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!") + return + + # STEP 1: check if the model called tools + if finish_reason != "tool_calls": + print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!") + return + + # STEP 2: call tools + tool_calls = completions.choices[0].message.tool_calls + print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools") + tasks = [] + for tool_call in tool_calls: + tasks.append(self._call_tool(tool_call)) + tool_responses = await asyncio.gather(*tasks) + if any(isinstance(item, Exception) for item in tool_responses): + print( + f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Error when calling tools, " + f"done!" + ) + return + messages.extend(tool_responses) + + # STEP 3: resubmit completion request with tool responses + self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info) + + async def _call_tool(self, tool_call) -> dict[str, str]: + """Call tool and return tool response.""" + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool = self.tools[tool_name] + + instance_id = await tool.create() + try: + tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args) + except Exception as e: + logger.exception(f"Error when executing tool: {e}") + return e + finally: + await tool.release(instance_id) + + return { + "role": "tool", + "content": tool_response, + "tool_call_id": tool_call.id, + } + + def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts: [prompt] from input dataset + prompts = [ + self.tokenizer.apply_chat_template( + prompt, tools=self.tool_schemas, add_generation_prompt=True, tokenize=False + ) + for prompt in batch.non_tensor_batch["raw_prompt"] + ] + assert len(batch_conversations) == len(prompts) * n + + # sequences: [prompt + response] + sequences = [ + self.tokenizer.apply_chat_template( + conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False + ) + for conversation in batch_conversations + ] + + # responses: [response] + responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] + + prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") + responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") + if n > 1: + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) + + # response_mask: response mask with tools calling masked out + response_mask = self._mask_out_tools_calling_tokens( + batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0), + batch_conversations, + responses["input_ids"], + responses["attention_mask"], + ) + + input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) + attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompts["input_ids"], # [bsz, prompt_length] + "responses": responses["input_ids"], # [bsz, response_length] + "response_mask": response_mask, # [bsz, response_length] + "input_ids": input_ids, # [bsz, prompt_length + response_length] + "attention_mask": attention_mask, # [bsz, prompt_length + response_length] + "position_ids": position_ids, # [bsz, prompt_length + response_length] + }, + batch_size=len(input_ids), + ) + + num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32) + return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}) + + def _mask_out_tools_calling_tokens( + self, + raw_prompts: list[list[dict[str, str]]], + batch_conversations: list[list[dict[str, str]]], + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + """Mask out tools calling tokens in the responses. + + Args: + raw_prompts: [prompt] from input dataset + batch_conversations: [prompt + response] + input_ids: responses tokens + attention_mask: responses attention mask + + Returns: + mask: (batch_size, response_length) + """ + batch_size = input_ids.size(0) + assert len(raw_prompts) == batch_size, f"{len(raw_prompts)} != {batch_size}" + assert len(batch_conversations) == batch_size, f"{len(batch_conversations)} != {batch_size}" + + # Deduplicate adjacent tool calls, since they're merged into one turn. + # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant] + # TODO: it's chat_template specific, find a more generic way to do this. + def deduplicate_adjacent_tool_calls(roles): + result = [] + for role, group in itertools.groupby(roles): + if role == "tool": + result.append(role) + else: + result.extend(group) + return result + + loss_mask = attention_mask.clone() + for i in range(batch_size): + responses = batch_conversations[i][len(raw_prompts[i]) :] + assert len(responses) > 0, f"responses is empty: {responses}" + + roles = deduplicate_adjacent_tool_calls([response["role"] for response in responses]) + # Each turn should be: [BOS]...[EOS] + eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)] + for j in range(len(roles)): + if roles[j] == "tool": + bos = eos_indices[j - 1] + 1 if j > 0 else 0 + eos = eos_indices[j] + loss_mask[i, bos : eos + 1] = 0 + + return loss_mask + + +@deprecated("verl.experimental.agent_loop.AgentLoopManager") +class ChatCompletionScheduler: + def __init__( + self, + config: DictConfig, + server_addresses: list[str], + max_cache_size: int = 10000, + ): + """ + Args: + config: DictConfig. + server_addresses: List[str], OpenAI compatible server addresses. + max_cache_size: int, max cache size of request_id to address mapping. + """ + self.config = config.actor_rollout_ref.rollout + model_path = config.actor_rollout_ref.model.path + self.model_name = "/".join(model_path.split("/")[-2:]) + + # Least requests load balancing + self.weighted_addresses = [[0, address] for address in server_addresses] + heapq.heapify(self.weighted_addresses) + + # LRU cache to map request_id to address + self.request_id_to_address = LRUCache(maxsize=max_cache_size) + + self.background_tasks = set() + if self.config.multi_turn.completion_callback is None: + self.completion_callback = ToolCompletionCallback(config, self) + logger.warning("completion_callback is None, use ToolCompletionCallback") + else: + module_path, class_name = self.config.multi_turn.completion_callback.rsplit(".", 1) + module = importlib.import_module(module_path) + self.completion_callback = getattr(module, class_name)(config, self) + + def submit_chat_completions(self, *, messages: list[dict[str, str]], request_id: str, info: dict[str, Any]): + """Submit chat completion request without wait, completion_callback will be called when the request is done. + + Args: + messages: List of messages. + request_id: Request id. + info: Any other auxiliary information pass across multi-turn. + """ + info["__depth__"] += 1 + task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info)) + + # “fire-and-forget” background tasks + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + + async def _submit_chat_completions_and_callback( + self, + messages: list[dict[str, str]], + request_id: str, + info: dict[str, Any], + ): + """Submit chat completion request, wait request finish and do callback.""" + if request_id: + request_id = request_id.removeprefix("chatcmpl-") + assert request_id in self.request_id_to_address + address = self.request_id_to_address.pop(request_id) + else: + address = self.weighted_addresses[0][1] + self.weighted_addresses[0][0] += 1 + heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) + + # use new request_id to avoid duplicate request_id problem + request_id = uuid4().hex + self.request_id_to_address[request_id] = address + + completions, exception = None, None + try: + # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. + completions = await self._chat_completions_aiohttp( + address, + messages=messages, + tools=self.completion_callback.tool_schemas, + extra_body=self.completion_callback.extra_body, + extra_headers={"x-request-id": request_id}, + **info["__sampling_params__"], + ) + except Exception as e: + # Let user handle the exception + exception = e + + info["__depth__"] -= 1 + + if exception is not None: + logger.exception(f"chat completion failed with exception: {exception}") + else: + try: + await self.completion_callback(messages, completions, info) + except Exception as e: + logger.exception(f"completion callback failed with exception: {e}") + + # No more ongoing completion requests + if info["__depth__"] == 0: + info["__done__"].set() + + async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: + client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0) + return await client.chat.completions.create(**chat_complete_request) + + async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: + try: + extra_body = chat_complete_request.pop("extra_body", {}) + chat_complete_request.update(extra_body or {}) + extra_headers = chat_complete_request.pop("extra_headers") + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + async with session.post( + url=f"http://{address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123", **extra_headers}, + json=chat_complete_request, + ) as resp: + data = await resp.json() + return ChatCompletion(**data) + finally: + await session.close() + + async def generate_sequences(self, batch: DataProto) -> DataProto: + t_start = time.time() + kwargs = dict( + model=self.model_name, + temperature=self.config.temperature, + top_p=self.config.top_p, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + kwargs["top_p"] = self.config.val_kwargs.top_p + kwargs["temperature"] = self.config.val_kwargs.temperature + + print(f"[ChatCompletionScheduler] generate_sequences sampling params: {kwargs}") + + # NOTE: For multi-turn rollout, repeat raw_prompt n times and process each prompt independently, + # validation dataset has already been repeated in `PPOTrainer._validate`. + n = 1 if batch.meta_info.get("validate", False) else self.config.n + tasks, batch_conversations = [], [None] * len(batch) * n + for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): + # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] + batch_conversations[batch_index] = conversation.tolist() + + tasks.append( + asyncio.create_task( + self._submit_chat_completions_semaphore( + messages=batch_conversations[batch_index], + request_id=None, + sampling_params=kwargs, + ) + ) + ) + + await asyncio.gather(*tasks) + output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) + output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} + print("[ChatCompletionScheduler] generate_sequences done") + return output_batch + + async def _submit_chat_completions_semaphore( + self, messages: list[dict[str, str]], request_id: str, sampling_params: dict[str, Any] + ): + done = asyncio.Event() + + info = { + "__done__": done, + "__depth__": 0, # indicate how many ongoing completion requests + "__sampling_params__": sampling_params, + } + + self.submit_chat_completions(messages=messages, request_id=request_id, info=info) + + # Wait until all completion requests are done + await done.wait() diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py index 9fe3d649e..32d0bc8a5 100644 --- a/verl/workers/rollout/hf_rollout.py +++ b/verl/workers/rollout/hf_rollout.py @@ -28,7 +28,7 @@ from transformers import GenerationConfig from verl import DataProto -from verl.utils.device import get_torch_device +from verl.utils.device import get_device_name, get_torch_device from verl.utils.torch_functional import get_response_mask from .base import BaseRollout @@ -52,7 +52,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: @torch.no_grad() def _generate_minibatch(self, prompts: DataProto) -> DataProto: - # make sampling args can be overriden by inputs + # make sampling args can be overridden by inputs do_sample = prompts.meta_info.get("do_sample", self.config.do_sample) is_validate = prompts.meta_info.get("validate", False) @@ -106,7 +106,7 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: if isinstance(self.module, FSDP): # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069 param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False) - with param_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): output = self.module.generate( input_ids=idx, attention_mask=attention_mask, @@ -152,7 +152,9 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) batch = TensorDict( diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py index 1c5df09fa..99f860acd 100644 --- a/verl/workers/rollout/schemas.py +++ b/verl/workers/rollout/schemas.py @@ -12,17 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import difflib +import logging +import os from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Optional import torch -from pydantic import BaseModel -from transformers import PreTrainedTokenizer +from pydantic import BaseModel, ConfigDict, model_validator +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin from verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema from verl.utils.model import compute_position_id_with_mask +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +BASE_CHAT_HISTORY = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "I am a user."}, +] + class FinishReasonTypeEnum(str, Enum): """The enum for finish reason type.""" @@ -45,8 +55,8 @@ def from_str(cls, value: str) -> "FinishReasonTypeEnum": class Message(BaseModel): role: str - content: str - tool_calls: Optional[List[OpenAIFunctionToolCall]] = None + content: str | dict[str, Any] | list[dict[str, Any]] + tool_calls: Optional[list[OpenAIFunctionToolCall]] = None class AsyncRolloutRequestStateEnum(str, Enum): @@ -57,153 +67,420 @@ class AsyncRolloutRequestStateEnum(str, Enum): COMPLETED = "completed" FAILED = "failed" TOOL_CALLING = "tool_calling" + INTERACTING = "interacting" + + +class TokenizationSanityCheckModeEnum(str, Enum): + """The enum for tokenization sanity check mode.""" + + DISABLE = "disable" + STRICT = "strict" + IGNORE_STRIPPABLE = "ignore_strippable" class AsyncRolloutRequest(BaseModel): """The data model for async rollout.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + batch_data_id: int = 0 rollout_offset: int = 0 request_id: str state: AsyncRolloutRequestStateEnum - messages: List[Message] - tools: Optional[List[OpenAIFunctionToolSchema]] = None - tools_kwargs: Dict[str, Any] = {} - input_ids: List[int] - prompt_ids: List[int] - response_ids: List[int] - attention_mask: List[int] - prompt_attention_mask: List[int] - response_attention_mask: List[int] - position_ids: List[int] - prompt_position_ids: List[int] - response_position_ids: List[int] - loss_mask: List[int] - prompt_loss_mask: List[int] - response_loss_mask: List[int] - reward_scores: Dict[str, float] + messages: list[Message] + multi_modal_keys: Optional[list[str]] = None + multi_modal_data: Optional[dict[str, Any]] = None + multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None + tool_schemas: Optional[list[OpenAIFunctionToolSchema]] = None + tools_kwargs: dict[str, Any] = {} + interaction_kwargs: dict[str, Any] = {} + input_ids: Optional[torch.Tensor] = None + prompt_ids: Optional[torch.Tensor] = None + response_ids: Optional[torch.Tensor] = None + attention_mask: Optional[torch.Tensor] = None + prompt_attention_mask: Optional[torch.Tensor] = None + response_attention_mask: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + prompt_position_ids: Optional[torch.Tensor] = None + response_position_ids: Optional[torch.Tensor] = None + loss_mask: Optional[torch.Tensor] = None + prompt_loss_mask: Optional[torch.Tensor] = None + response_loss_mask: Optional[torch.Tensor] = None + reward_scores: dict[str, float] + max_prompt_len: int max_response_len: int = 8192 max_model_len: int = 32768 - metrics: Dict[str, List[Any]] = {} - - format_config: dict = { - "chatml": { - "assistant_prefix_msg": "\n<|im_start|>assistant\n", - "assistant_suffix_msg": "<|im_end|>", - "tool_prefix_msg": "\n<|im_start|>tool\n", - "tool_suffix_msg": "<|im_end|>", - }, - "qwen": { - "assistant_prefix_msg": "\n<|im_start|>assistant\n", - "assistant_suffix_msg": "<|im_end|>", - "merge_tool_response": True, - "tool_prefix_msg": "\n<|im_start|>user", - "tool_suffix_msg": "<|im_end|>", - "tool_response_prefix_msg": "\n\n", - "tool_response_suffix_msg": "\n", - }, - } - - def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]: - return tokenizer.apply_chat_template( # type: ignore - conversation=[msg.model_dump() for msg in self.messages], - tools=[tool.model_dump() for tool in self.tools] if self.tools else None, + metrics: dict[str, list[Any]] = {} + + use_inference_chat_template: bool + tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum + generation_prompt_ids: Optional[torch.Tensor] = None + base_conv_wo_gen_prompt_end_pos: int + base_conv_with_gen_prompt_end_pos: int + + @model_validator(mode="before") + @classmethod + def initialize_request(cls, values): + if not (messages := values.get("messages")): + raise ValueError("messages is required for AsyncRolloutRequest initialization") + if not (max_prompt_len := values.get("max_prompt_len")): + raise ValueError("max_prompt_len is required for AsyncRolloutRequest initialization") + if not (processing_class := values.pop("processing_class", None)): + raise ValueError("processing_class is required for AsyncRolloutRequest initialization") + + values["messages"] = [Message.model_validate(msg) for msg in messages] + + # If there is no multi_modal_keys, we assume the multi-modal data is image and video. + if not values.get("multi_modal_keys"): + values["multi_modal_keys"] = ["image", "video"] + if not values.get("multi_modal_data"): + values["multi_modal_data"] = {key: [] for key in values["multi_modal_keys"]} + else: + # check if all multi_modal_keys are in multi_modal_data + for key in values["multi_modal_keys"]: + if key not in values["multi_modal_data"]: + values["multi_modal_data"][key] = [] + if not values.get("multi_modal_inputs"): + values["multi_modal_inputs"] = {} + + tools = ( + [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None + ) + + multi_modal_data = values["multi_modal_data"] + tokens_without_prompt = cls._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + ) + if ( + values.get("input_ids") is None + or values.get("attention_mask") is None + or values.get("position_ids") is None + ): + tokenization_dict_with_prompt = cls._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + ) + + values["input_ids"], values["attention_mask"] = ( + tokenization_dict_with_prompt["input_ids"], + tokenization_dict_with_prompt["attention_mask"], + ) + if values["input_ids"].shape[-1] > max_prompt_len: + # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an + # error for this case in the future. + logger.warning( + f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} " + f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools." + ) + + # Process multi_modal_inputs + multi_modal_inputs = tokenization_dict_with_prompt.copy() + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + values["multi_modal_inputs"] = multi_modal_inputs + + values["position_ids"] = values["prompt_position_ids"] = cls._get_position_ids( + processing_class, values["input_ids"], values["attention_mask"], multi_modal_inputs + ) + + values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] + values["loss_mask"] = values["prompt_loss_mask"] = torch.zeros_like(values["input_ids"], dtype=torch.bool) + values["generation_prompt_ids"] = values["input_ids"][..., tokens_without_prompt.shape[-1] :] + values["base_conv_wo_gen_prompt_end_pos"] = cls._handle_apply_chat_template( + processing_class, + BASE_CHAT_HISTORY, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + ).shape[-1] + + values["base_conv_with_gen_prompt_end_pos"] = cls._handle_apply_chat_template( + processing_class, + BASE_CHAT_HISTORY, + multi_modal_data=multi_modal_data, + tools=tools, add_generation_prompt=True, tokenize=True, + ).shape[-1] + + return values + + @staticmethod + def _handle_apply_chat_template( + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + messages: list[Message], + multi_modal_data: dict[str, Any], + tools: Optional[list[OpenAIFunctionToolSchema]] = None, + add_generation_prompt: bool = False, + tokenize: bool = False, + return_dict: bool = False, + ): + raw_prompt = processing_class.apply_chat_template( + messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False ) + if not tokenize: + return raw_prompt - def add_assistant_message( + if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast): + if any(len(values) > 0 for values in multi_modal_data.values()): + logger.warning( + "There is multi_modal_data but you are not using a processor. Multi-modal data will be ignored." + ) + model_inputs = processing_class(text=[raw_prompt], return_tensors="pt") + elif isinstance(processing_class, ProcessorMixin): + # When we update multi_model_keys, we also need to update this logic + images = images if len(images := multi_modal_data.get("image", [])) > 0 else None + videos = videos if len(videos := multi_modal_data.get("video", [])) > 0 else None + model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") + else: + raise ValueError(f"Unsupported processing class type: {type(processing_class)}") + + model_inputs = dict(model_inputs) + if return_dict: + return model_inputs + else: + return model_inputs["input_ids"] + + @staticmethod + def _get_position_ids( + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + # special case for qwen2vl + is_qwen2vl = ( + hasattr(processing_class, "image_processor") + and "Qwen2VLImageProcessor" in processing_class.image_processor.__class__.__name__ + ) + if is_qwen2vl: + from verl.models.transformers.qwen2_vl import get_rope_index + + image_grid_thw = video_grid_thw = second_per_grid_ts = None + if multi_modal_inputs: + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + video_grid_thw = multi_modal_inputs.get("video_grid_thw") + second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") + + assert input_ids.dim() == 2 and input_ids.shape[0] == 1, ( + f"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}" + ) + assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, ( + f"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}" + ) + new_position_ids = get_rope_index( + processing_class, + input_ids=input_ids.squeeze(0), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask.squeeze(0), + ) + return new_position_ids # (3, seq_len) + else: + return compute_position_id_with_mask(attention_mask) # (1, seq_len) + + def _update_input_ids( self, - tokenizer: PreTrainedTokenizer, - content: str, - tool_calls: Optional[List[OpenAIFunctionToolCall]] = None, - format: Literal["chatml", "qwen"] = "chatml", - already_over_long: bool = False, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + new_input_ids: torch.Tensor, + attention_mask: bool, + loss_mask: bool, + new_multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, ) -> None: - """Currently, we only support chatml format.""" - msg = Message(role="assistant", content=content, tool_calls=tool_calls) - self.messages.append(msg) - if tool_calls is not None: - content_with_tool_calls: str = tokenizer.apply_chat_template( # type: ignore - conversation=[msg.model_dump()], add_generation_prompt=False, tokenize=False + """ + Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner. + """ + self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1) + attention_mask = torch.ones_like(new_input_ids) * int(attention_mask) + self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1) + loss_mask = torch.ones_like(new_input_ids) * int(loss_mask) + self.loss_mask = torch.cat([self.loss_mask, loss_mask], dim=-1) + + if new_multi_modal_inputs: + self._update_multi_modal_inputs(new_multi_modal_inputs) + + new_position_ids = self._get_position_ids( + processing_class, new_input_ids, attention_mask, new_multi_modal_inputs + ) + + last_pos = self.position_ids[..., -1:] + new_position_ids = new_position_ids + (last_pos + 1) + + self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1) + + assert ( + self.input_ids.shape[-1] + == self.attention_mask.shape[-1] + == self.position_ids.shape[-1] + == self.loss_mask.shape[-1] + ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, + {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" + + def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None: + """ + Update the multi_modal_inputs of the request in additive manner. + """ + for key in new_multi_modal_inputs: + input_tensor = new_multi_modal_inputs[key] + self.multi_modal_inputs[key] = ( + torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0) + if key in self.multi_modal_inputs + else input_tensor ) + + def get_generation_prompt_ids( + self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ) -> list[int]: + """ + Get the generation prompt ids for rollout engine. + + Because rollout engine(SGLang) requires the ids to be a list, we need to convert the tensor to a list. + """ + generation_prompt_ids = ( + None + if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all() + else self.generation_prompt_ids + ) + if generation_prompt_ids is not None: + self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False) + + if self.use_inference_chat_template: + messages = [msg.model_dump() for msg in self.messages] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + generation_prompt_ids = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=self.multi_modal_data, + tools=tools, + add_generation_prompt=True, + tokenize=True, + ) + return generation_prompt_ids.squeeze(0).tolist() else: - content_with_tool_calls = content - # TODO: support other formats - if format in self.format_config: - prefix_msg = self.format_config[format]["assistant_prefix_msg"] - prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) - suffix_msg = self.format_config[format]["assistant_suffix_msg"] - suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) - if tool_calls is not None: - content = content_with_tool_calls.split(f"{prefix_msg}")[-1].split(f"{suffix_msg}")[0] - content_token_ids = tokenizer.encode(content, add_special_tokens=False) - if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids: - append_token_ids = content_token_ids - _loss_mask = [1] * len(content_token_ids) - elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: - append_token_ids = prefix_token_ids + content_token_ids - _loss_mask = [0] * len(prefix_token_ids) + [1] * len(content_token_ids) - else: - max_len = max(len(prefix_token_ids), len(suffix_token_ids)) - raise ValueError( - f"""Unsupported end of message format: - {tokenizer.decode(self.input_ids[-max_len:])}, - {tokenizer.decode(self.input_ids)=}, {self.messages=}""" - ) - if not already_over_long: - append_token_ids += suffix_token_ids - _loss_mask += [1] * len(suffix_token_ids) - self.input_ids += append_token_ids - _attention_mask = [1] * len(append_token_ids) - self.attention_mask += _attention_mask - _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - last_position_id = self.position_ids[-1] - _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] - self.loss_mask += _loss_mask - self.position_ids += _position_ids - else: - raise ValueError(f"Unsupported format: {format}") - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" - - def add_tool_response_message(self, tokenizer: PreTrainedTokenizer, content: str, last_tool: bool, format: Literal["chatml", "qwen"] = "chatml") -> None: - """Currently, we only support chatml format.""" - msg = Message(role="tool", content=content) - self.messages.append(msg) - # TODO: support other formats - if format in self.format_config: - merge_tool_responses = self.format_config[format].get("merge_tool_response", False) - prefix_msg = self.format_config[format]["tool_prefix_msg"] - prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) - suffix_msg = self.format_config[format]["tool_suffix_msg"] - suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) - prefix_resp = self.format_config[format].get("tool_response_prefix_msg", "") - prefix_resp_token_ids = tokenizer.encode(prefix_resp, add_special_tokens=False) - suffix_resp = self.format_config[format].get("tool_response_suffix_msg", "") - suffix_resp_token_ids = tokenizer.encode(suffix_resp, add_special_tokens=False) - full_suffix_token_ids = suffix_resp_token_ids + (suffix_token_ids if last_tool or not merge_tool_responses else []) - content_token_ids = tokenizer.encode(content, add_special_tokens=False) - if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids or self.input_ids[-len(suffix_resp_token_ids) :] == suffix_resp_token_ids: - append_token_ids = prefix_resp_token_ids + content_token_ids + full_suffix_token_ids - elif self.input_ids[-len(prefix_resp_token_ids) :] == prefix_resp_token_ids: - append_token_ids = content_token_ids + full_suffix_token_ids - elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: - append_token_ids = prefix_token_ids + prefix_resp_token_ids + content_token_ids + full_suffix_token_ids + return self.input_ids.squeeze(0).tolist() + + def add_user_message( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + content: str, + ) -> None: + self.messages.append(Message(role="user", content=content)) + messages = [*BASE_CHAT_HISTORY, self.messages[-1]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine + # Inference, it is pure text. + content_ids = self._handle_apply_chat_template( + processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + )[..., self.base_conv_wo_gen_prompt_end_pos :] + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False) + + def add_assistant_message( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + content: str, + tool_calls: Optional[list[OpenAIFunctionToolCall]] = None, + ) -> None: + self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) + + messages = [*BASE_CHAT_HISTORY, self.messages[-1]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine + # Inference, it is pure text. + content_ids = self._handle_apply_chat_template( + processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + )[..., self.base_conv_with_gen_prompt_end_pos :] + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True) + + def add_tool_response_messages( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + contents: list[str | dict[str, Any]], + ) -> None: + if not contents: + return + # We also handle the case when tool returns image + # We require the processing of the image and video to be done at tool.execute() level + delta_multi_modal_data = {key: [] for key in self.multi_modal_keys} + for content in contents: + if isinstance(content, dict): + content_list = [] + # When we update multi_model_keys, we also need to update this logic + if "image" in content: + if not isinstance(content["image"], list): + raise ValueError( + f"Image must be a list, but got {type(content['image'])}. Please check the tool.execute(). " + f"For single images, wrap in a list: [image]. " + f"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}." + ) + + content_list.extend([{"type": "image"} for _ in content["image"]]) + delta_multi_modal_data["image"].extend(content["image"]) + if "video" in content: + if not isinstance(content["video"], list): + raise ValueError( + f"Video must be a list, but got {type(content['video'])}. Please check the tool.execute(). " + f"For single videos, wrap in a list: [video]. " + f"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}." + ) + + content_list.extend([{"type": "video"} for _ in content["video"]]) + delta_multi_modal_data["video"].extend(content["video"]) + if "text" in content: + content_list.append({"type": "text", "text": content["text"]}) + for key in content: + if key not in ["image", "video", "text"]: + logger.warning( + f"Tool response message contains unexpected key: {key} " + f"while we only support `image`, `video`, and `text`." + ) + self.messages.append(Message(role="tool", content=content_list)) else: - raise ValueError(f"Unsupported end of message format: {tokenizer.decode(self.input_ids[-len(prefix_token_ids) :])}") - self.input_ids += append_token_ids - _attention_mask = [1] * len(append_token_ids) - self.attention_mask += _attention_mask - _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - last_position_id = self.position_ids[-1] - _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] - self.loss_mask += [0] * len(append_token_ids) - self.position_ids += _position_ids - else: - raise ValueError(f"Unsupported format: {format}") - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + self.messages.append(Message(role="tool", content=content)) + + messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + for key in self.multi_modal_keys: + if len(delta_multi_modal_data[key]) > 0: + self.multi_modal_data[key].extend(delta_multi_modal_data[key]) + + # We just passed the new multi-modal data to the chat template to update the input_ids. + content_info = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=delta_multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + content_ids = content_info["input_ids"][..., self.base_conv_wo_gen_prompt_end_pos :] + + # process multi_modal_inputs + multi_modal_inputs = content_info.copy() + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + self._update_input_ids( + processing_class, + content_ids, + attention_mask=True, + loss_mask=False, + new_multi_modal_inputs=multi_modal_inputs, + ) def update_metrics(self, metrics: Any, tool_id: str) -> None: """ @@ -213,31 +490,186 @@ def update_metrics(self, metrics: Any, tool_id: str) -> None: self.metrics[tool_id] = [] self.metrics[tool_id].append(metrics) + def _get_prompt_diffs( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + full_prompt_ids: torch.Tensor, + current_prompt_ids: torch.Tensor, + diff_surrounding_chars: int = 10, + ) -> list[dict[str, Any]]: + """Get differences between full prompt and current prompt with surrounding context. + + This function helps debug tokenization mismatches by showing the differences between + full prompt and current prompt with surrounding context. Instead of just showing + the exact diff, it includes additional tokens before and after to help locate + the issue in the chat template. + + For example, if the actual diff is a newline change from "\n\n" to "\n", with + diff_surrounding_chars the output might look like: + + full_prompt_chunk: "<|im_start|>assistant\n\nI think..." + current_prompt_chunk: "<|im_start|>assistant\nI think..." + + This context makes it much easier to identify where in the chat template the + mismatch occurs. + + Args: + processing_class: The processing class to use for decoding the token IDs + full_prompt_ids: Token IDs from applying chat template to all messages at once + current_prompt_ids: Token IDs from incremental chat template application + diff_surrounding_chars: Number of surrounding characters to include for context (default: 10) + + Returns: + List of dicts containing the differing chunks with context and their indices + """ + full_prompt_ids = full_prompt_ids.squeeze(0) + current_prompt_ids = current_prompt_ids.squeeze(0) + full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False) + current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False) + s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False) + diffs = [] + for tag, i1, i2, j1, j2 in s.get_opcodes(): + if tag == "equal": + continue + + # Get the surrounding context for better readability + start_i = max(0, i1 - diff_surrounding_chars) + end_i = min(len(full_prompt), i2 + diff_surrounding_chars) + start_j = max(0, j1 - diff_surrounding_chars) + end_j = min(len(current_prompt), j2 + diff_surrounding_chars) + + diffs.append( + { + "full_prompt_chunk": full_prompt[start_i:end_i], + "current_prompt_chunk": current_prompt[start_j:end_j], + "indices": (start_i, end_i, start_j, end_j), + } + ) + return diffs + def finalize( self, - tokenizer: PreTrainedTokenizer, - reward_scores: Dict[str, float], + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + reward_scores: dict[str, list[float]], finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP, ) -> None: self.state = AsyncRolloutRequestStateEnum.COMPLETED self.reward_scores = reward_scores - self.response_ids = self.input_ids[len(self.prompt_ids) :] + + # In case we failed to generate the assistant message and the generation prompt ids were already added to + # input_ids, remove them from the end of input_ids + if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all(): + self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]] + self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]] + self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]] + self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]] + + self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :] + + if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE: + # When there is a diff, we log the diffs with diff_surrounding_chars context + diff_surrounding_chars = 10 + + messages = [msg.model_dump() for msg in self.messages] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + full_prompt_info = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=self.multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + full_prompt_ids = full_prompt_info["input_ids"] + + # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict + # because np.array() only keeps the keys for BatchFeature. + full_prompt_multi_modal_inputs = full_prompt_info.copy() + full_prompt_multi_modal_inputs.pop("input_ids", None) + full_prompt_multi_modal_inputs.pop("attention_mask", None) + + for multi_modal_inputs_key in self.multi_modal_inputs: + if multi_modal_inputs_key in full_prompt_multi_modal_inputs: + if ( + not self.multi_modal_inputs[multi_modal_inputs_key] + .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key]) + .all() + ): + logger.warning( + f"Multi-modal data {multi_modal_inputs_key} is not consistent. " + f"This may lead to unexpected behavior during training. " + f"Please review your multi_modal_inputs logic." + ) + else: + logger.warning( + f"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. " + f"This may lead to unexpected behavior during training." + f"Please review your multi_modal_inputs logic." + ) + + if diffs := self._get_prompt_diffs( + processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars + ): + log_warning = False + if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT: + log_warning = True + elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE: + non_strippable_diffs_exist = any( + d["full_prompt_chunk"].strip() or d["current_prompt_chunk"].strip() for d in diffs + ) + if non_strippable_diffs_exist: + log_warning = True + + if log_warning: + mode_str = f" ({self.tokenization_sanity_check_mode.value})" + logger.warning( + f"Inconsistent training and inference tokenization detected{mode_str}. This may lead to " + f"unexpected behavior during training. Please review your chat template to determine if this " + f"is intentional. For more information, refer to the multiturn README.md." + ) + logger.warning( + f"Showing {diff_surrounding_chars} characters before and after the diffs for context and " + f"better readability." + ) + diff_details_list = [] + for d in diffs: + i1, i2, j1, j2 = d["indices"] + diff_details_list.append( + f"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | " + f"current_prompt_chunk: {repr(d['current_prompt_chunk'])}" + ) + diff_details = "\n".join(diff_details_list) + logger.warning(f"Found differences:\n{diff_details}") + if finish_reason_type == FinishReasonTypeEnum.STOP: pass elif finish_reason_type == FinishReasonTypeEnum.LENGTH: pass else: raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") - self.truncate_output_ids(tokenizer) - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" - - def truncate_output_ids(self, tokenizer: PreTrainedTokenizer) -> None: - self.input_ids = self.input_ids[: self.max_model_len] - self.attention_mask = self.attention_mask[: self.max_model_len] - self.position_ids = self.position_ids[: self.max_model_len] - self.loss_mask = self.loss_mask[: self.max_model_len] - self.response_ids = self.input_ids[len(self.prompt_ids) :][: self.max_response_len] - self.response_attention_mask = self.attention_mask[len(self.prompt_attention_mask) :][: self.max_response_len] - self.response_position_ids = self.position_ids[len(self.prompt_position_ids) :][: self.max_response_len] - self.response_loss_mask = self.loss_mask[len(self.prompt_loss_mask) :][: self.max_response_len] + self.truncate_output_ids(processing_class) + + assert ( + self.input_ids.shape[-1] + == self.attention_mask.shape[-1] + == self.position_ids.shape[-1] + == self.loss_mask.shape[-1] + ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, + {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" + + def truncate_output_ids( + self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ) -> None: + self.input_ids = self.input_ids[..., : self.max_model_len] + self.attention_mask = self.attention_mask[..., : self.max_model_len] + self.position_ids = self.position_ids[..., : self.max_model_len] + self.loss_mask = self.loss_mask[..., : self.max_model_len] + self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len] + self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][ + ..., : self.max_response_len + ] + self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][ + ..., : self.max_response_len + ] + self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len] diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index b3a368188..df26765c2 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -14,6 +14,7 @@ # limitations under the License. import asyncio import logging +from typing import Any import ray from omegaconf import DictConfig @@ -29,44 +30,57 @@ class AsyncSglangServer(AsyncServerBase): def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str): super().__init__() - self.config = config - rollout_config = config.get("rollout", {}) - self._tp_size = rollout_config.get("tensor_model_parallel_size", 1) + self.config = config.actor_rollout_ref + self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1) self._dp_size = dp_size self._dp_rank = dp_rank self.wg_prefix = wg_prefix self.workers = [] + self.master_worker = None async def init_engine(self): + if self.workers: + # avoid init twice + return all_actors = ray.util.list_named_actors(all_namespaces=True) - matched_actors = [actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_")] + matched_actors = [ + actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_") + ] - # TODO support multi node for matched_actor in matched_actors: - current_rank = int(matched_actor["name"].split(":")[-1]) + fields = matched_actor["name"].split(":") + assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}" + pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - # send to all works in this tp group, because sglang is SPMD - if current_rank >= self._dp_rank * self._tp_size and current_rank < (self._dp_rank + 1) * self._tp_size: - self.workers.append(ray.get_actor(**matched_actor)) + if (self._dp_size * pg_index + local_rank) // self._tp_size == self._dp_rank: + worker = ray.get_actor(**matched_actor) + self.workers.append(worker) + if (self._dp_size * pg_index + local_rank) / self._tp_size == self._dp_rank: + self.master_worker = worker async def chat_completion(self, raw_request: Request): request = await raw_request.json() - output_dp_lst = [] - for worker in self.workers: - output_future = worker.execute_method.remote("chat_completion", request) - output_dp_lst.append(output_future) - outputs = await asyncio.gather(*output_dp_lst) + # only send request to master worker in tp rank 0 + output_future = self.master_worker.chat_completion.remote(request) + [outputs] = await asyncio.gather(output_future) + return JSONResponse(outputs) - for output in outputs: - if output is not None: - return JSONResponse(output) - raise RuntimeError("AsyncSglangServer No output from workers self._dp_rank: {self._dp_rank}, self._tp_size: {self._tp_size}, self.workers: {self.workers}") + async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id) async def wake_up(self): - for worker in self.workers: - worker.resume.remote() + if not self.config.rollout.free_cache_engine: + return + + tasks = [worker.wake_up.remote() for worker in self.workers] + if tasks: + await asyncio.gather(*tasks) async def sleep(self): - for worker in self.workers: - worker.offload.remote() + if not self.config.rollout.free_cache_engine: + return + + tasks = [worker.sleep.remote() for worker in self.workers] + if tasks: + await asyncio.gather(*tasks) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index cf4e20580..3c6694325 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -17,42 +17,51 @@ import asyncio import logging +import multiprocessing as mp import os import time -from contextlib import contextmanager from copy import deepcopy from json import JSONDecodeError -from typing import Union +from typing import Any, List, Optional, Tuple from uuid import uuid4 import numpy as np +import sglang.srt.entrypoints.engine import torch import torch.distributed as dist from omegaconf import DictConfig -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.openai_api.protocol import Tool +from sglang.srt.managers.tokenizer_manager import ( + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromTensorReqInput, +) from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.utils import get_ip, get_open_port +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + get_ip, + get_open_port, + is_cuda, + maybe_set_triton_cache_manager, + set_prometheus_multiproc_dir, + set_ulimit, +) from tensordict import TensorDict from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.nn.utils.rnn import pad_sequence -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin from verl import DataProto +from verl.interactions.base import BaseInteraction +from verl.interactions.utils.interaction_registry import initialize_interactions_from_config from verl.third_party.sglang import parallel_state as sglang_ps from verl.tools.base_tool import BaseTool -from verl.tools.schemas import ( - OpenAIFunctionCallSchema, - OpenAIFunctionParsedSchema, - OpenAIFunctionToolCall, -) -from verl.utils.debug import GPUMemoryLogger -from verl.utils.model import compute_position_id_with_mask +from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall +from verl.tools.utils.tool_registry import initialize_tools_from_config from verl.utils.net_utils import is_ipv6 -from verl.utils.torch_functional import ( - get_response_mask, - pad_sequence_to_length, -) +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.schemas import ( AsyncRolloutRequest, @@ -67,28 +76,139 @@ except ImportError: from sglang.srt.function_call_parser import FunctionCallParser +try: + from sglang.srt.entrypoints.openai.protocol import Tool +except ImportError: + from sglang.srt.openai_api.protocol import Tool + logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723 +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer_python", + "0.2.5", + "Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.", + ) + if is_cuda(): + assert_pkg_version( + "sgl-kernel", + "0.1.1", + "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", + ) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config + + +# because chatCompletion is an async method, it makes the whole ray actor be an async actor +# which can not call loop.run_until_complete. So we need to make the engine to be an async class +class AsyncEngine(sglang.srt.entrypoints.engine.Engine): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # default to use dummy load format, which need to reload weights in first time + self._need_reload = True + + async def release_memory_occupation(self, tags: Optional[list[str]] = None): + """Release GPU occupation temporarily.""" + if tags is None: + obj = ReleaseMemoryOccupationReqInput() + else: + obj = ReleaseMemoryOccupationReqInput(tags=tags) + return await self.tokenizer_manager.release_memory_occupation(obj, None) + + async def resume_memory_occupation(self, tags: Optional[list[str]] = None): + """Resume GPU occupation.""" + # because __init__ is a sync method, it can not call the async release_memory_occupation + # have to move release_memory_occupation from __init__ to here + # For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time. + if self._need_reload: + await self.release_memory_occupation() + self._need_reload = False + + if tags is None: + obj = ResumeMemoryOccupationReqInput() + else: + obj = ResumeMemoryOccupationReqInput(tags=tags) + return await self.tokenizer_manager.resume_memory_occupation(obj, None) + + async def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, torch.Tensor]], # noqa: UP006 + load_format: Optional[str] = None, + flush_cache: bool = True, + ): + """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false + to avoid duplicated cache cleaning operation.""" + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=[ + MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) + ], + load_format=load_format, + flush_cache=flush_cache, + ) + return await self.tokenizer_manager.update_weights_from_tensor(obj, None) + + async def flush_cache(self): + return await self.tokenizer_manager.flush_cache() + + # NOTE(sgm): add for verl. We can optimize it by making # the dataloader yield List[int] without padding. def _pre_process_inputs( pad_token_id, prompt_token_ids: torch.Tensor, -) -> list[int]: +) -> torch.Tensor: # remove the left padding in the prompt token_id non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids + return prompt_token_ids[non_pad_index:] # NOTE(linjunrong): adhoc -def _post_process_outputs(tokenizer, output): +def _post_process_outputs(processing_class, output): + try: + # This is when processing_class is a processor + tokenizer = processing_class.tokenizer + except AttributeError: + try: + # This is when processing_class is a tokenizer + tokenizer = processing_class + except AttributeError as e: + raise ValueError(f"Cannot get tokenizer from processing_class {processing_class}") from e + def _map_each_response(resp): output_token_logprobs = resp["meta_info"]["output_token_logprobs"] - log_probs, output_token_ids = zip(*[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs]) + log_probs, output_token_ids = zip( + *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True + ) return torch.tensor(output_token_ids), torch.tensor(log_probs) out_map = map(lambda x: _map_each_response(x), output) @@ -104,14 +224,28 @@ def _map_each_response(resp): return batched_output_token_ids, batched_logprobs -def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str: +def get_tool_call_parser_type( + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, +) -> str: items = FunctionCallParser.ToolCallParserEnum.items() for parser_type, parser_cls in items: parser = parser_cls() - if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()): + try: + # This is when processing_class is a tokenizer + tokenizer_vocab = processing_class.get_vocab() + except AttributeError: + try: + # This is when processing_class is a processor + tokenizer_vocab = processing_class.tokenizer.get_vocab() + except AttributeError as e: + raise ValueError(f"Cannot get vocab from processing_class {processing_class}") from e + + if parser.bot_token.strip() in tokenizer_vocab and ( + parser.eot_token == "" or parser.eot_token.strip() in tokenizer_vocab + ): return parser_type else: - raise ValueError(f"No tool call parser found for tokenizer {tokenizer}") + raise ValueError(f"No tool call parser found for processing_class {processing_class}") class SGLangRollout(BaseRollout): @@ -119,7 +253,7 @@ def __init__( self, actor_module: str, config: DictConfig, - tokenizer, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, model_hf_config, port=None, trust_remote_code: bool = False, @@ -134,7 +268,7 @@ def __init__( config: A DictConfig object containing SGLang-specific operational parameters and rollout settings. Refer to https://docs.sglang.ai/backend/server_arguments.html - tokenizer: The tokenizer instance compatible with the actor_module. + processing_class: The tokenizer or processor instance compatible with the actor_module. model_hf_config: The Hugging Face model's configuration (e.g., `transformers.PretrainedConfig`). It provides architectural details and hyperparameters like `max_position_embeddings`, @@ -159,12 +293,15 @@ def __init__( self._tool_call_parser_type, self._sgl_tools, self._function_call_parser, - ) = self._initialize_tools(config, tokenizer) + ) = self._initialize_tools(config, processing_class) + self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config) # If turn on `free_cache_engine`, SGLang engine's KV cache # will be freed after each `generate_sequences` call. - assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" - - logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}") + logger.info( + f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: " + f"{self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: " + f"{self._function_call_parser}" + ) self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs) @@ -174,14 +311,25 @@ def __init__( self._init_sampling_params(**kwargs) - self.tokenizer = tokenizer - self.pad_token_id = tokenizer.pad_token_id + self.processing_class = processing_class + + try: + # This is when processing_class is a tokenizer + self.pad_token_id = self.processing_class.pad_token_id + except AttributeError: + try: + # This is when processing_class is a processor + self.pad_token_id = self.processing_class.tokenizer.pad_token_id + except AttributeError as e: + raise ValueError(f"Cannot get pad_token_id from processing_class {self.processing_class}") from e def _init_distributed_env(self, device_mesh_cpu, **kwargs): self._device_mesh_cpu = device_mesh_cpu os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert self.tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size" + assert self.tensor_parallel_size <= dist.get_world_size(), ( + "tensor parallel size should be less than or equal to the world size" + ) self.train_tp = kwargs.get("train_tp", None) if self.train_tp is not None: # deployed with megatron @@ -214,19 +362,55 @@ def _init_distributed_env(self, device_mesh_cpu, **kwargs): # get tp_rank of this process in this tp group visible_devices = [None] * self._device_mesh_cpu.size(1) - torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp")) + torch.distributed.all_gather_object( + visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp") + ) self.visible_devices_set = set(",".join(visible_devices).split(",")) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set))) def _verify_config(self, model_hf_config): if not self.config.get("max_model_len", None): self.config.max_model_len = self.config.prompt_length + self.config.response_length - assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): + assert ( + self.config.max_model_len >= self.config.prompt_length + self.config.response_length + ), f"""max_model_len should be greater than total sequence length (prompt_length + response_length): {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" - assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length" - # currently max_turns stand for max number of tool calls - if self.config.multi_turn.max_turns is None: - self.config.multi_turn.max_turns = self.config.max_model_len // 3 + max_position_embeddings = None + if hasattr(model_hf_config, "max_position_embeddings"): + max_position_embeddings = model_hf_config.max_position_embeddings + elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"): + max_position_embeddings = model_hf_config.llm_config.max_position_embeddings + elif hasattr(model_hf_config, "text_config") and hasattr( + model_hf_config.text_config, "max_position_embeddings" + ): + max_position_embeddings = model_hf_config.text_config.max_position_embeddings + if max_position_embeddings is None: + raise ValueError("max_position_embeddings not found in model_hf_config") + rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) + if not rope_scaling_config: + assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, ( + "model context length should be greater than total sequence length" + ) + else: + # handle type where there's a length extend factor + # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support + # for using yarn as an example + rope_scaling_factor = rope_scaling_config.get("factor", 1.0) + + assert ( + model_hf_config.max_position_embeddings * rope_scaling_factor + >= self.config.prompt_length + self.config.response_length + ), ( + f"model context length should be greater than total sequence length, " + f"got rope_scaling_factor={rope_scaling_factor} and " + f"max_position_embeddings={model_hf_config.max_position_embeddings}" + ) + + # currently max_assistant_turns stand for max number of tool calls + if self.config.multi_turn.max_assistant_turns is None: + self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3 + if self.config.multi_turn.max_user_turns is None: + self.config.multi_turn.max_user_turns = self.config.max_model_len // 3 def _init_inference_engine(self, trust_remote_code, actor_module, port): # initialize the inference engine @@ -253,7 +437,7 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): if first_rank_in_node: rank = dist.get_rank() os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine( + self._engine = AsyncEngine( model_path=actor_module, dtype=self.config.dtype, mem_fraction_static=self.config.gpu_memory_utilization, @@ -276,14 +460,15 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): # log_requests=True, # log_requests_level=2, # max_running_requests=1, + mm_attention_backend="fa3", + attention_backend="fa3", + # In async mode, we want token in token out. + skip_tokenizer_init=self.config.mode == "async", ) else: self._engine = None self.sharding_manager = None - # offload - if self._tp_rank == 0: - self._engine.release_memory_occupation() self.is_sleep = True def _init_sampling_params(self, **kwargs): @@ -296,11 +481,12 @@ def _init_sampling_params(self, **kwargs): ) # supporting adding any sampling params from the config file for k in self.config.keys(): - if hasattr(SamplingParams(), str(k)): + if hasattr(SamplingParams(), str(k)) or "stop" in str(k): kwargs[k] = self.config.get(k) + kwargs["n"] = 1 # already repeat in ray_trainer self.sampling_params = kwargs - def _initialize_tools(self, config, tokenizer): + def _initialize_tools(self, config, processing_class): """Initialize tools from configuration. Args: @@ -327,48 +513,13 @@ def _initialize_tools(self, config, tokenizer): if config.multi_turn.tool_config_path is None: return [], {}, None, [], None - import importlib.util - import sys - - from omegaconf import OmegaConf - - from verl.tools.schemas import OpenAIFunctionToolSchema - - def initialize_tools_from_config(tools_config) -> list: - tool_list = [] - - for tool_config in tools_config.tools: - cls_name = tool_config.class_name - module_name, class_name = cls_name.rsplit(".", 1) - - if module_name not in sys.modules: - spec = importlib.util.find_spec(module_name) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - else: - module = sys.modules[module_name] - - tool_cls = getattr(module, class_name) - - tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) - tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) - - tool = tool_cls( - config=OmegaConf.to_container(tool_config.config, resolve=True), - tool_schema=tool_schema, - ) - tool_list.append(tool) - - return tool_list - tools_config_file = config.multi_turn.tool_config_path - tools_config = OmegaConf.load(tools_config_file) - tool_list = initialize_tools_from_config(tools_config) + tool_list = initialize_tools_from_config(tools_config_file) + logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}") tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] tool_map = {tool.name: tool for tool in tool_list} - tool_call_parser_type = get_tool_call_parser_type(tokenizer) + tool_call_parser_type = get_tool_call_parser_type(processing_class) sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas] function_call_parser = FunctionCallParser( sgl_tools, @@ -383,38 +534,44 @@ def initialize_tools_from_config(tools_config) -> list: function_call_parser, ) - @contextmanager - def update_sampling_params(self, **kwargs): - """ - Temporarily updates the model's sampling parameters for the - duration of a `with` block. Parameters are automatically fall - back to their original values upon exiting the block. + def _initialize_interactions(self, config): + """Initialize interactions from configuration. - Args: - **kwargs: Keyword arguments representing sampling parameters - to be updated. Only parameters that already exist in - `self.sampling_params` will be updated. + Returns: + dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances. """ - # Store original values of parameters that will be updated - old_sampling_params_args = {key: self.sampling_params[key] for key in kwargs if key in self.sampling_params} + if config.multi_turn.interaction_config_path is None: + return {} - # Update sampling parameters with new values - for key, value in kwargs.items(): - if key in self.sampling_params: - self.sampling_params[key] = value + interaction_config_file = config.multi_turn.interaction_config_path + interaction_map = initialize_interactions_from_config(interaction_config_file) - try: - yield - # Yield and execute the code within the 'with' block - finally: - # Always restore original values, even if an error - # occurred in the `with` block - for key, value in old_sampling_params_args.items(): - self.sampling_params[key] = value + logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}") + return interaction_map @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + """Generate sequences for a batch of prompts. + + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch. + - prompts: [bsz, prompt_length], prompt token ids from dataset. + - responses: [bsz, response_length], output token ids include response tokens + from LLM generation and observation tokens from tool_calls. + - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. + - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens + and response tokens. + - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. + - position_ids: [bsz, prompt_length + response_length], incremental position ids. + + For multi-turn conversations: + responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| + response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| + """ if self.config.multi_turn.enable: return self._req_level_generate_sequences(prompts, **kwargs) return self._batch_level_generate_sequences(prompts, **kwargs) @@ -422,11 +579,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - """Generates sequences for a batch of prompts. + """Generates single-turn sequences for a batch of prompts. For single-turn generation, all prompts are processed in one request. - For multi-turn generation, each prompt is processed separately via - `_generate_req_level_sequences` for better tool calling control. - `_generate_batch_level_sequences` involves: + `_batch_level_generate_sequences` involves: 1. Extracting and pre-processing prompt token IDs from the input `prompts`. This includes handling padding and preparing raw token ID lists. @@ -461,10 +616,8 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP `input_ids` (concatenated prompt and response), `attention_mask`, and `position_ids` for the full sequences. - Note that when `n > 1`, each prompt generates multiple sequences, - so we need to replicate its non-tensor data (i.e. raw prompts, - messages, reward scores, etc.) n times to match the expanded - tensor data. This is done in the `_non_tensor_batch` dictionary. + Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer. + Thus we do not need to repeat the prompts here and set the sampling parameter n to 1. """ # input ids: (bs, prompt_length), left-padded idx = prompts.batch["input_ids"] @@ -482,7 +635,7 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP non_tensor_batch = prompts.non_tensor_batch if "raw_prompt_ids" not in non_tensor_batch: non_tensor_batch["raw_prompt_ids"] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], + [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)], dtype=object, ) @@ -491,23 +644,30 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP for raw_prompt_ids, multi_modal_data in zip( non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data"), + strict=True, ): sglang_inputs.append( { "prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data, - "image_data": (multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None), + "image_data": ( + multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None + ), } ) else: - sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] + sglang_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] # Ensure token IDs are lists or numpy arrays for input_data in sglang_inputs: if isinstance(input_data["prompt_token_ids"], np.ndarray): input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + raise TypeError( + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) # Extract token IDs and image data for SGLang Engine idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] @@ -515,88 +675,93 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP do_sample = prompts.meta_info.get("do_sample", True) is_validate = prompts.meta_info.get("validate", False) + + # Create request-level sampling parameters + request_sampling_params = self.sampling_params.copy() if not do_sample: - kwargs = dict( - n=1, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - temperature=0, - top_p=1, - top_k=-1, - ignore_eos=False, - min_new_tokens=0, - max_new_tokens=self.config.response_length, - skip_special_tokens=True, - spaces_between_special_tokens=True, + request_sampling_params.update( + { + "n": 1, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "temperature": 0, + "top_p": 1, + "top_k": -1, + "ignore_eos": False, + "min_new_tokens": 0, + "max_new_tokens": self.config.response_length, + "skip_special_tokens": True, + "spaces_between_special_tokens": True, + } ) elif is_validate: - kwargs = dict( - top_k=self.config.val_kwargs.top_k, - top_p=self.config.val_kwargs.top_p, - temperature=self.config.val_kwargs.temperature, - n=1, # if validate, already repeat in ray_trainer + request_sampling_params.update( + { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } ) - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs): - # print(f"{self.sampling_params=}") - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - return_logprob=True, - input_ids=idx_list, - image_data=image_list, - ) - ) - else: - output = None + # Update with any additional kwargs + request_sampling_params.update(kwargs) - # Most naive implementation, can extract tensor and send via gloo if too slow - dist.barrier() - [output] = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, + if self._tp_rank == 0: + loop = asyncio.get_event_loop() + output = loop.run_until_complete( + self._engine.async_generate( + prompt=None, # because we have already convert it to prompt token id + sampling_params=request_sampling_params, + return_logprob=True, + input_ids=idx_list, + image_data=image_list, + ) ) - out = _post_process_outputs(self.tokenizer, output) + else: + output = None + + # Most naive implementation, can extract tensor and send via gloo if too slow + dist.barrier() + [output] = broadcast_pyobj( + data=[output], + rank=self._rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + out = _post_process_outputs(self.processing_class, output) - response = out[0].to(idx.device) + response = out[0].to(idx.device) + rollout_log_probs = None + if self.config.calculate_log_probs: rollout_log_probs = out[1].to(idx.device) - if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id) - - # utilize current sampling params - if self.sampling_params.get("n", 1) > 1 and do_sample: - idx = idx.repeat_interleave(self.sampling_params["n"], dim=0) - attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0) - position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0) - batch_size = batch_size * self.sampling_params["n"] - _non_tensor_batch = {} - for key, val in non_tensor_batch.items(): - _non_tensor_batch[key] = np.repeat(val, self.sampling_params["n"], axis=0) - else: - _non_tensor_batch = non_tensor_batch - seq = torch.cat([idx, response], dim=-1) + if response.shape[1] < self.config.response_length: + response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) + if self.config.calculate_log_probs: + rollout_log_probs = pad_sequence_to_length( + rollout_log_probs, self.config.response_length, self.pad_token_id + ) + + seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) + if position_ids.dim() == 3: # qwen2vl mrope + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[:, -1:] + delta_position_id + response_position_ids = position_ids[..., -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid @@ -605,18 +770,21 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - "rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=batch_size, ) + if self.config.calculate_log_probs: + # we will recompute old log prob with actor + batch["rollout_log_probs"] = rollout_log_probs # free cache engine - if self.config.free_cache_engine and self._engine is not None: - self._engine.flush_cache() + if self._engine is not None and self._tp_rank == 0: + loop = asyncio.get_event_loop() + loop.run_until_complete(self._engine.flush_cache()) - return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) async def _async_rollout_a_request( self, @@ -631,7 +799,42 @@ async def _async_rollout_a_request( output = None current_turns = 0 - while current_turns < self.config.multi_turn.max_turns: + user_turns = 0 + user_turn_rewards = [] + + # Create request-level sampling parameters + request_sampling_params = self.sampling_params.copy() + if not do_sample: + request_sampling_params.update( + { + "n": 1, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "temperature": 0, + "top_p": 1, + "top_k": -1, + "ignore_eos": False, + "min_new_tokens": 0, + "max_new_tokens": self.config.response_length, + "skip_special_tokens": True, + "spaces_between_special_tokens": True, + } + ) + elif is_validate: + request_sampling_params.update( + { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } + ) + + # Update with any additional kwargs + request_sampling_params.update(kwargs) + + while current_turns < self.config.multi_turn.max_assistant_turns: if _req.state == AsyncRolloutRequestStateEnum.PENDING: await self._handle_pending_state(_req) _req.state = AsyncRolloutRequestStateEnum.RUNNING @@ -648,16 +851,9 @@ async def _async_rollout_a_request( for tool_call in parsed_tool_calls ] ) - for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)): - _req.add_tool_response_message( - self.tokenizer, - resp, - (i == len(parsed_tool_calls) - 1), - format=self.config.multi_turn.format, - ) + _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results]) + for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results, strict=True): _req.update_metrics(metrics, tool_call.function.name) - if len(_req.input_ids) >= self.config.max_model_len: - break if len(_req.input_ids) >= self.config.max_model_len: finish_reason_type = FinishReasonTypeEnum.STOP break @@ -665,17 +861,35 @@ async def _async_rollout_a_request( else: raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: - output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs) + # Only continue the conversation if the prompt length is not greater than max_model_len - 1, + # since SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra + # token accounts for the EOS token). + if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len: + finish_reason_type = FinishReasonTypeEnum.LENGTH + break + + # Video support is not implemented yet + image_data = ( + _req.multi_modal_data["image"] + if _req.multi_modal_data and "image" in _req.multi_modal_data + else None + ) + video_data = ( + _req.multi_modal_data["video"] + if _req.multi_modal_data and "video" in _req.multi_modal_data + else None + ) + if video_data: + logger.warning( + "video support is not implemented yet, current length of video data is %d", len(video_data) + ) + + output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data) content = output["text"] finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) current_turns += 1 if finish_reason_type == FinishReasonTypeEnum.LENGTH: - _req.add_assistant_message( - self.tokenizer, - content, - already_over_long=True, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.processing_class, content) break else: if self._function_call_parser and self._function_call_parser.has_tool_call(content): @@ -708,29 +922,59 @@ async def _async_rollout_a_request( ) if len(parsed_tool_calls) > 0: _req.add_assistant_message( - self.tokenizer, - normed_content, - tool_calls=parsed_tool_calls, - format=self.config.multi_turn.format, + self.processing_class, normed_content, tool_calls=parsed_tool_calls ) else: - _req.add_assistant_message( - self.tokenizer, - content, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.processing_class, content) finish_reason_type = FinishReasonTypeEnum.STOP _req.state = AsyncRolloutRequestStateEnum.COMPLETED break else: _req.add_assistant_message( - self.tokenizer, + self.processing_class, content, - format=self.config.multi_turn.format, ) + if ( + _req.interaction_kwargs + and self.interaction_map + and user_turns < self.config.multi_turn.max_user_turns + and current_turns < self.config.multi_turn.max_assistant_turns + ): + _req.state = AsyncRolloutRequestStateEnum.INTERACTING + else: + break + elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING: + user_turns += 1 + messages = [{"role": x.role, "content": x.content} for x in _req.messages] + + # Get interaction by name from interaction_kwargs + interaction_name = _req.interaction_kwargs.get( + "name", "gsm8k" + ) # Default to gsm8k for backward compatibility + if interaction_name not in self.interaction_map: + raise ValueError( + f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " + f"{list(self.interaction_map.keys())}" + ) + + interaction = self.interaction_map[interaction_name] + should_terminate_sequence, content, reward, metrics = await interaction.generate_response( + _req.request_id, messages, **_req.interaction_kwargs + ) + user_turn_rewards.append(reward) + if should_terminate_sequence: + finish_reason_type = FinishReasonTypeEnum.STOP + _req.state = AsyncRolloutRequestStateEnum.COMPLETED + break + else: + _req.add_user_message(self.processing_class, content) + if len(_req.input_ids) >= self.config.max_model_len: + finish_reason_type = FinishReasonTypeEnum.STOP break + else: + _req.state = AsyncRolloutRequestStateEnum.RUNNING - if current_turns >= self.config.multi_turn.max_turns: + if current_turns >= self.config.multi_turn.max_assistant_turns: finish_reason_type = FinishReasonTypeEnum.STOP # Calculate the reward for each tool @@ -745,56 +989,52 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool): tool_reward_tasks.append(calc_reward_and_release_fn(name, tool)) tool_reward_scores = await asyncio.gather(*tool_reward_tasks) tool_reward_scores = dict(tool_reward_scores) - _req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type) + all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}} + _req.finalize(self.processing_class, all_rewards, finish_reason_type) return _req - async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict: - generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) + async def _handle_engine_call( + self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None + ) -> dict: + generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class) + return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data) + + async def _handle_engine_generate( + self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None + ) -> dict: max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) - if not do_sample: - kwargs = dict( - n=1, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - temperature=0, - top_p=1, - top_k=-1, - ignore_eos=False, - min_new_tokens=0, - max_new_tokens=self.config.response_length, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ) - elif is_validate: - # TODO: try ** - kwargs = { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } + kwargs = sampling_params.copy() kwargs["max_new_tokens"] = max_new_tokens - if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess - kwargs["n"] = 1 - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs): - output = await self._engine.async_generate( - input_ids=generation_prompt_ids, - sampling_params=self.sampling_params, - return_logprob=False, - ) + kwargs["n"] = 1 # group size is supported in preprocess + output = await self._engine.async_generate( + input_ids=generation_prompt_ids, + sampling_params=kwargs, + return_logprob=False, + image_data=image_data, + ) return output async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: - if _req.tools is not None: + if _req.tool_schemas is not None: tool_creation_coroutines = [] - for tool_schema in _req.tools: + for tool_schema in _req.tool_schemas: tool = self._tool_map[tool_schema.function.name] create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) await asyncio.gather(*tool_creation_coroutines) + if _req.interaction_kwargs and self.interaction_map: + interaction_kwargs = _req.interaction_kwargs + # Get interaction by name from interaction_kwargs + interaction_name = interaction_kwargs.get("name", "gsm8k") # Default to gsm8k for backward compatibility + if interaction_name not in self.interaction_map: + raise ValueError( + f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " + f"{list(self.interaction_map.keys())}" + ) + + interaction = self.interaction_map[interaction_name] + await interaction.start_interaction(_req.request_id, **interaction_kwargs) @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() @@ -809,6 +1049,12 @@ def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataPro @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + """Generates multi-turn sequences for a batch of prompts. + For multi-turn generation, each prompt is processed separately via + `_req_level_generate_sequences` for better tool calling control. + Note that in multi-turn generation, we repeat the prompts for rollout.n times in ray_trainer. + Thus we do not need to repeat the prompts here and set the sampling parameter n to 1. + """ # Async rollout with tools support do_sample = prompts.meta_info.get("do_sample", True) is_validate = prompts.meta_info.get("validate", False) @@ -816,7 +1062,6 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro if self._tp_rank == 0: req_list = self._preprocess_prompt_to_async_rollout_requests( prompts, - n=1 if is_validate else self.config.n, ) loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( @@ -843,37 +1088,46 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro prompt_loss_mask, response_loss_mask = [], [] messages = [] reward_scores = [] + multi_modal_inputs = [] + for req in sorted_output_req_list: assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of - {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" + assert ( + req.input_ids.shape[-1] + == req.attention_mask.shape[-1] + == req.position_ids.shape[-1] + == req.loss_mask.shape[-1] + ), f"""Request {req.request_id} has different length of + {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, + {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}""" error_message_lines = [ - f"""Request {req.request_id} has input_ids length {len(req.input_ids)} + f"""Request {req.request_id} has input_ids length {req.input_ids.shape[-1]} greater than max_model_len {self.config.max_model_len}""", - f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}", - f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}", - f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}", + f"Decoded input_ids: {self.processing_class.decode(req.input_ids.squeeze(0))}", + f"Decoded prompt_ids: {self.processing_class.decode(req.prompt_ids.squeeze(0))}", + f"Decoded response_ids: {self.processing_class.decode(req.response_ids.squeeze(0))}", f"Messages: {req.messages}", f"Max model length: {req.max_model_len}", ] error_message = "\n".join(error_message_lines) - assert len(req.input_ids) <= self.config.max_model_len, error_message + assert req.input_ids.shape[-1] <= self.config.max_model_len, error_message - prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device)) - response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device)) - if len(req.response_ids) > self.config.response_length: + prompt_ids.append(req.prompt_ids.to(tgt_device).squeeze(0)) + response_ids.append(req.response_ids.to(tgt_device).squeeze(0)) + if req.response_ids.shape[-1] > self.config.response_length: logger.warning( - f"""{req.request_id=} has response_ids length {len(req.response_ids)} + f"""{req.request_id=} has response_ids length {req.response_ids.shape[-1]} greater than max_response_len {self.config.response_length},\n{req=}""" ) - prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device)) - response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device)) - prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device)) - response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device)) - prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device)) - response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) + prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0)) + response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0)) + prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0)) + response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0)) + prompt_loss_mask.append(req.prompt_loss_mask.to(tgt_device).squeeze(0)) + response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0)) messages.append({"messages": req.messages}) reward_scores.append(req.reward_scores) + multi_modal_inputs.append(req.multi_modal_inputs) prompt_ids = pad_sequence( prompt_ids, @@ -881,10 +1135,10 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro padding_value=self.pad_token_id, padding_side="left", ) - if prompt_ids.shape[1] < self.config.prompt_length: + if prompt_ids.shape[-1] < self.config.prompt_length: prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) - if response_ids.shape[1] < self.config.response_length: + if response_ids.shape[-1] < self.config.response_length: response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) prompt_attention_mask = pad_sequence( prompt_attention_mask, @@ -892,18 +1146,46 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro padding_value=0, padding_side="left", ) - if prompt_attention_mask.shape[1] < self.config.prompt_length: - prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True) + if prompt_attention_mask.shape[-1] < self.config.prompt_length: + prompt_attention_mask = pad_sequence_to_length( + prompt_attention_mask, self.config.prompt_length, 0, left_pad=True + ) response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) - if response_attention_mask.shape[1] < self.config.response_length: + if response_attention_mask.shape[-1] < self.config.response_length: response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) - prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") - if prompt_position_ids.shape[1] < self.config.prompt_length: - prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True) - response_length = response_ids.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1) - response_position_ids = prompt_position_ids[:, -1:] + delta_position_id + + # padding prompt_position_ids + if prompt_position_ids[0].dim() == 2: + # if prompt_position_ids is a 2D tensor + # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len) + transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids] + prompt_position_ids = pad_sequence( + transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + prompt_position_ids = prompt_position_ids.transpose(1, 2) + else: + prompt_position_ids = pad_sequence( + prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + if prompt_position_ids.shape[-1] < self.config.prompt_length: + prompt_position_ids = pad_sequence_to_length( + prompt_position_ids, self.config.prompt_length, 0, left_pad=True + ) + + # padding response_position_ids + if response_position_ids[0].dim() == 2: + # if response_position_ids is a 2D tensor + # e.g. from qwen2vl, response_position_ids.shape = (3, seq_len) + transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids] + response_position_ids = pad_sequence( + transposed_response_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + response_position_ids = response_position_ids.transpose(1, 2) + else: + response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0) + if response_position_ids.shape[-1] < self.config.response_length: + response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0) + prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") if prompt_loss_mask.shape[1] < self.config.prompt_length: prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) @@ -914,189 +1196,196 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro input_ids = torch.cat((prompt_ids, response_ids), dim=-1) attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) - loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1) # Construct the batch data batch = TensorDict( { "prompts": prompt_ids, "responses": response_ids, + "response_mask": response_loss_mask, "input_ids": input_ids, # here input_ids become the whole sentences "attention_mask": attention_mask, "position_ids": position_ids, - "loss_mask": loss_mask, }, batch_size=len(sorted_output_req_list), ) # free cache engine - if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0: - self._engine.flush_cache() + if self._engine is not None and self._tp_rank == 0: + loop = asyncio.get_event_loop() + loop.run_until_complete(self._engine.flush_cache()) return DataProto( batch=batch, non_tensor_batch={ "messages": np.array(messages), "reward_scores": np.array(reward_scores), + "multi_modal_inputs": np.array(multi_modal_inputs, dtype=object), }, ) - def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]: - assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages" + def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]: + assert "raw_prompt" in prompts.non_tensor_batch, ( + "need data.return_raw_chat=True, due to no official way do parse_messages" + ) + logger.info( + "n is deprecated for SGLang rollout since ray ppo trainer will repeat the prompts for rollout.n times" + ) req_list = [] - for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]): - for rollout_offset in range(n): - if self._tool_schemas: - _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] - _tool_schemas = [] - for k in _tools_kwargs.keys(): - _tool_schemas.append(self._tool_map[k].get_openai_tool_schema()) - prompt_with_chat_template = self.tokenizer.apply_chat_template( - conversation=raw_prompt, - tools=[tool.model_dump() for tool in _tool_schemas], - add_generation_prompt=True, - tokenize=False, - return_tensors="pt", - ) - input_data = self.tokenizer( - prompt_with_chat_template, - return_tensors="pt", - add_special_tokens=False, - ) - _input_ids = input_data["input_ids"][0].tolist() - _attention_mask = input_data["attention_mask"][0].tolist() - _position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist() - if len(_input_ids) > self.config.prompt_length: - logger.warning( - "Prompt {} has length {} greater than max_prompt_len {}", - data_idx, - len(_input_ids), - self.config.prompt_length, - ) - _input_ids = _input_ids[: self.config.prompt_length] - _attention_mask = _attention_mask[: self.config.prompt_length] - _position_ids = _position_ids[: self.config.prompt_length] - else: - _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) - _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) - _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - _tool_schemas = [] - _tools_kwargs = {} - - req = AsyncRolloutRequest( - batch_data_id=data_idx, - rollout_offset=rollout_offset, - request_id=str(uuid4()), - state=AsyncRolloutRequestStateEnum.PENDING, - messages=[Message.model_validate(msg) for msg in raw_prompt], - tools=_tool_schemas, - tools_kwargs=_tools_kwargs, - input_ids=_input_ids, - prompt_ids=_input_ids, - response_ids=[], - attention_mask=_attention_mask, - prompt_attention_mask=_attention_mask, - response_attention_mask=[], - position_ids=_position_ids, - prompt_position_ids=_position_ids, - response_position_ids=[], - loss_mask=[0] * len(_input_ids), - prompt_loss_mask=[0] * len(_input_ids), - response_loss_mask=[], - reward_scores={}, - max_response_len=self.config.response_length, - max_model_len=min( - self.config.max_model_len, - self.config.prompt_length + self.config.response_length, - ), - ) + multi_modal_data_list = prompts.non_tensor_batch.get( + "multi_modal_data", [None] * len(prompts.non_tensor_batch["raw_prompt"]) + ) - error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message + for data_idx, (raw_prompt, multi_modal_data) in enumerate( + zip(prompts.non_tensor_batch["raw_prompt"], multi_modal_data_list, strict=True) + ): + if self._tool_schemas: + _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] + _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()] + _input_ids = None + _attention_mask = None + else: + _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) + _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) + _tools_kwargs = {} + _tool_schemas = None - req_list.append(req) + if self.interaction_map: + _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx] + else: + _interaction_kwargs = {} + + req = AsyncRolloutRequest( + batch_data_id=data_idx, + rollout_offset=0, + request_id=str(uuid4()), + state=AsyncRolloutRequestStateEnum.PENDING, + messages=raw_prompt.tolist(), + multi_modal_data=multi_modal_data, + tool_schemas=_tool_schemas, + tools_kwargs=_tools_kwargs, + interaction_kwargs=_interaction_kwargs, + input_ids=_input_ids, + response_ids=None, + attention_mask=_attention_mask, + response_attention_mask=None, + response_position_ids=None, + response_loss_mask=None, + reward_scores={}, + max_prompt_len=self.config.prompt_length, + max_response_len=self.config.response_length, + max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, + tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, + processing_class=self.processing_class, + ) + error_message = f"""Request {req.request_id} has mismatched lengths: + input_ids={req.input_ids.shape[-1]}, + attention_mask={req.attention_mask.shape[-1]}, + position_ids={req.position_ids.shape[-1]}, + loss_mask={req.loss_mask.shape[-1]}""" + assert ( + req.input_ids.shape[-1] + == req.attention_mask.shape[-1] + == req.position_ids.shape[-1] + == req.loss_mask.shape[-1] + ), error_message + req_list.append(req) return req_list - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - if method == "chat_completion": - json_request = args[0] - - formatted_messages = [] - for msg in json_request["messages"]: - role = msg.get("role", "user") - content = msg.get("content", "") - formatted_messages.append(f"{role}: {content}") - prompt_str = "\n".join(formatted_messages) - - sampling_params_dict = { - "n": json_request.get("n", 1), - "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length), - "temperature": json_request.get("temperature", 1.0), - "top_p": json_request.get("top_p", 1.0), - } - output = None - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=prompt_str, - sampling_params=sampling_params_dict, - return_logprob=True, - ) - ) + async def chat_completion(self, json_request): + assert self._tp_rank == 0, "only called in tp rank 0" + _input_ids = None + _attention_mask = None + _position_ids = None + _tool_schemas = [] + _tools_kwargs = {} + + req = AsyncRolloutRequest( + request_id=str(uuid4()), + state=AsyncRolloutRequestStateEnum.PENDING, + messages=[Message.model_validate(msg) for msg in json_request["messages"]], + tool_schemas=_tool_schemas, + tools_kwargs=_tools_kwargs, + input_ids=_input_ids, + prompt_ids=_input_ids, + response_ids=None, + attention_mask=_attention_mask, + prompt_attention_mask=_attention_mask, + response_attention_mask=None, + position_ids=_position_ids, + prompt_position_ids=_position_ids, + response_position_ids=None, + loss_mask=None, + prompt_loss_mask=None, + response_loss_mask=None, + reward_scores={}, + max_prompt_len=self.config.prompt_length, + max_response_len=self.config.response_length, + max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, + tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, + processing_class=self.processing_class, + ) - dist.barrier() - output = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, + # json_request already contains sampling_params + # Filter only valid SamplingParams arguments + valid_sampling_params = {} + temp_sampling_params = SamplingParams() # Create temporary instance to check valid attributes + for k, v in json_request.items(): + if k not in ["messages", "model", "tools"] and hasattr(temp_sampling_params, k): + valid_sampling_params[k] = v + output = await self._handle_engine_call(req, valid_sampling_params) + # it can be Dict or AsyncIterator[Dict] + if isinstance(output, dict): + outputs = [output] + else: + outputs = output + + # build openai chat completion format + choices = [] + id = None + for i, content in enumerate(outputs): + choices.append( + { + "index": i, + "message": { + "role": "assistant", + "content": content["text"], + }, + "finish_reason": content["meta_info"]["finish_reason"]["type"], + } ) + id = content["meta_info"]["id"] - # only return value from master rank - if self._tp_rank != 0: - return None - # build openai chat completion format - choices = [] - id = None - for i, content in enumerate(output): - choices.append( - { - "index": i, - "message": { - "role": "assistant", - "content": content["text"], - }, - "finish_reason": content["meta_info"]["finish_reason"]["type"], - } - ) - id = content["meta_info"]["id"] - - return { - "id": "chatcmpl-" + id, - "object": "chat.completion", - "created": int(time.time()), - "model": json_request.get("model", "sglang_model"), - "choices": choices, - } - else: - raise ValueError(f"not supported method : {method}") + return { + "id": "chatcmpl-" + id, + "object": "chat.completion", + "created": int(time.time()), + "model": json_request.get("model", "sglang_model"), + "choices": choices, + } # this function is left for uniform train-inference resharding - def resume(self): + async def generate( + self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str + ) -> torch.Tensor: + request_sampling_params = self.sampling_params.copy() + request_sampling_params.update(sampling_params) + output = await self._handle_engine_generate(prompt_ids, request_sampling_params) + return output["output_ids"] + + async def wake_up(self): if not self.is_sleep: return - self.sharding_manager.__enter__() # pylint: disable=C2801 - + await self.sharding_manager.wake_up() # pylint: disable=C2801 self.is_sleep = False # this function is left for uniform train-inference resharding - def offload(self): + async def sleep(self): if self.is_sleep: return - - self.sharding_manager.__exit__(None, None, None) + await self.sharding_manager.sleep() self.is_sleep = True diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py index 438facd9e..f64bf63b8 100644 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ b/verl/workers/rollout/sglang_rollout/utils.py @@ -14,15 +14,17 @@ # limitations under the License. import pickle -from typing import Any, List, Optional +from typing import Any, Iterator, Optional import numpy as np import torch import torch.distributed as dist +from verl.utils.device import get_device_name + def broadcast_pyobj( - data: List[Any], + data: list[Any], rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, @@ -34,9 +36,7 @@ def broadcast_pyobj( The `rank` here refer to the source rank on global process group (regardless of dist_group argument). """ - device = torch.device( - "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" - ) + device = torch.device(get_device_name() if not force_cpu_device else "cpu") if rank == src: if len(data) == 0: @@ -46,9 +46,7 @@ def broadcast_pyobj( serialized_data = pickle.dumps(data) size = len(serialized_data) - tensor_data = torch.ByteTensor( - np.frombuffer(serialized_data, dtype=np.uint8) - ).to(device) + tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device) tensor_size = torch.tensor([size], dtype=torch.long, device=device) dist.broadcast(tensor_size, src=src, group=dist_group) @@ -68,3 +66,43 @@ def broadcast_pyobj( serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) return data + + +def get_named_tensor_buckets( + iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int +) -> Iterator[list[tuple[str, torch.Tensor]]]: + """ + Group tensors into buckets based on a specified size in megabytes. + + Args: + iterable: An iterator of tuples containing tensor names and tensors. + bucket_bytes: The maximum size of each bucket in bytes. + + Yields: + Lists of tuples, where each tuple contains a tensor name and its corresponding tensor. + + Example: + >>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))] + >>> for bucket in get_named_tensor_buckets(tensors, bucket_size_mb=10): + ... print(bucket) + [('tensor1', tensor(...)), ('tensor2', tensor(...))] + + """ + if bucket_bytes <= 0: + raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}") + + current_bucket = [] + current_size = 0 + for name, tensor in iterable: + tensor_size = tensor.element_size() * tensor.numel() + if current_size + tensor_size > bucket_bytes: + if current_bucket: + yield current_bucket + current_bucket = [(name, tensor)] + current_size = tensor_size + else: + current_bucket.append((name, tensor)) + current_size += tensor_size + + if current_bucket: + yield current_bucket diff --git a/verl/workers/rollout/tokenizer.py b/verl/workers/rollout/tokenizer.py index 435fa5bfe..1e1212e50 100644 --- a/verl/workers/rollout/tokenizer.py +++ b/verl/workers/rollout/tokenizer.py @@ -16,7 +16,6 @@ """ from abc import ABC, abstractmethod -from typing import Dict, List, Union import numpy as np import torch @@ -54,7 +53,7 @@ def eos_token_id(self): @property @abstractmethod - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: """ `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. """ @@ -62,7 +61,7 @@ def all_special_ids(self) -> List[int]: @property @abstractmethod - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: """ `List[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). @@ -89,7 +88,7 @@ def encode(self, text): @abstractmethod def decode( self, - token_ids: Union[int, List[int], np.ndarray, torch.Tensor], + token_ids: int | list[int] | np.ndarray | torch.Tensor, skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, **kwargs, @@ -117,7 +116,7 @@ def decode( pass @abstractmethod - def convert_ids_to_tokens(self, ids: Union[int, List[int]], skip_special_tokens: bool = False) -> Union[str, List[str]]: + def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]: """ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and added tokens. @@ -134,7 +133,7 @@ def convert_ids_to_tokens(self, ids: Union[int, List[int]], skip_special_tokens: pass @abstractmethod - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: """ Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from the fast call because for now we always add the tokens even if they are already in the vocabulary. This is @@ -146,7 +145,7 @@ def get_added_vocab(self) -> Dict[str, int]: pass @abstractmethod - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: """ Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we often want to remove sub-word tokenization artifacts at the same time. diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 450da77d5..767858fe3 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -14,7 +14,7 @@ import os from importlib.metadata import PackageNotFoundError, version -from packaging.version import Version +from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout # noqa: F401 def get_version(pkg): @@ -27,12 +27,11 @@ def get_version(pkg): vllm_package_name = "vllm" vllm_package_version = get_version(vllm_package_name) if vllm_package_version is None: - raise PackageNotFoundError("To use vllm rollout, please ensure the 'vllm' package is properly installed. See https://verl.readthedocs.io/en/latest/start/install.html for more details") + raise PackageNotFoundError( + "To use vllm rollout, please ensure the 'vllm' package is properly installed. See " + "https://verl.readthedocs.io/en/latest/start/install.html for more details" + ) -### -# package_version = get_version(package_name) -# [SUPPORT AMD:] -# Do not call any torch.cuda* API here, or ray actor creation import class will fail. if "ROCM_PATH" in os.environ: import re @@ -41,12 +40,3 @@ def get_version(pkg): vllm_package_version = match.group(1) else: raise ValueError(f"Warning: Could not parse version format: {vllm_package_version}") -### - -if Version(vllm_package_version) <= Version("0.6.3"): - vllm_mode = "customized" - from .fire_vllm_rollout import FIREvLLMRollout # noqa: F401 - from .vllm_rollout import vLLMRollout # noqa: F401 -else: - vllm_mode = "spmd" - from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout # noqa: F401 diff --git a/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py b/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py deleted file mode 100644 index 5fced091b..000000000 --- a/verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The vllm_rollout that can be applied in different backend -When working with FSDP: -- Use DTensor weight loader (recommended) or HF weight loader -- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM -When working with Megatron: -- Use Megatron weight loader -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - -from contextlib import contextmanager -from typing import List - -import torch -import torch.distributed -from omegaconf import DictConfig -from tensordict import TensorDict -from torch import nn -from vllm import SamplingParams - -from verl import DataProto -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout - -# TODO -# 1. support pp in vllm -# 2. passing tokenizer is not necessary? no encoding/decoding is happening here -# 3. simplify init logics - - -# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]: - # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids - - -class FIREvLLMRollout(vLLMRollout): - def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs): - """A vLLM rollout. It requires the module is supported by the vllm. - - Args: - module: module here follows huggingface APIs - config: DictConfig - tokenizer: the task/model tokenizer - model_hf_config: the huggingface config to initialize the generating model in vllm - **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group - """ - super().__init__(actor_module, config, tokenizer, model_hf_config, **kwargs) - - self.use_fire_sampling = config.get("use_fire_sampling", False) - if self.use_fire_sampling: - kwargs_0 = kwargs.copy() - kwargs_0["temperature"] = 30 - kwargs_0["max_tokens"] = 1 - if "top_k" not in kwargs_0 or kwargs_0["top_k"] <= 0: - kwargs_0["top_k"] = 16 - self.sampling_params.max_tokens = config.response_length - 1 - for k in config.keys(): - if hasattr(SamplingParams(), str(k)): - kwargs_0[k] = config.get(k) - self.sampling_params_0 = SamplingParams(**kwargs_0) - - @contextmanager - def update_sampling_params(self, **kwargs): - # update sampling params - old_sampling_params_args = {} - if kwargs: - for key, value in kwargs.items(): - if hasattr(self.sampling_params, key): - old_value = getattr(self.sampling_params, key) - old_sampling_params_args[key] = old_value - setattr(self.sampling_params, key, value) - if self.use_fire_sampling: - old_sampling_params_args_0 = {} - if kwargs: - for key, value in kwargs.items(): - if hasattr(self.sampling_params_0, key): - old_value = getattr(self.sampling_params_0, key) - old_sampling_params_args_0[key] = old_value - setattr(self.sampling_params_0, key, value) - yield - # roll back to previous sampling params - # if len(old_sampling_params_args): - for key, value in old_sampling_params_args.items(): - setattr(self.sampling_params, key, value) - if self.use_fire_sampling: - for key, value in old_sampling_params_args_0.items(): - setattr(self.sampling_params_0, key, value) - - @torch.no_grad() - def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - # rebuild vllm cache engine - if self.config.free_cache_engine: - self.inference_engine.init_cache_engine() - - idx = prompts.batch["input_ids"] # (bs, prompt_length) - # left-padded attention_mask - attention_mask = prompts.batch["attention_mask"] - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - - idx_list = [] - # parse idx from torch.Tensor to List[List[str]] - for i in range(batch_size): - idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) - - do_sample = prompts.meta_info.get("do_sample", True) - if not do_sample: - kwargs = { - "best_of": 1, - "top_p": 1.0, - "top_k": -1, - "min_p": 0.0, - "temperature": 0, - "n": 1, # if greedy, only 1 response - } - - if not self.use_fire_sampling: - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs): - output = self.inference_engine.generate( - prompts=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - prompt_token_ids=idx_list, - use_tqdm=False, - ) - - response = output[0].to(idx.device) # (bs, response_length) - else: - with self.update_sampling_params(**kwargs): - output_0 = self.inference_engine.generate( - prompts=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params_0, - prompt_token_ids=idx_list, - use_tqdm=False, - ) - new_idx_list = [] - for i in range(batch_size): - new_idx_list.append(idx_list[i] + output_0[0][i].tolist()) - output = self.inference_engine.generate( - prompts=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - prompt_token_ids=new_idx_list, - use_tqdm=False, - ) - - response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device) # (bs, response_length) - # log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device) # (bs, response_length) - - if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) - - if self.config.n > 1 and do_sample: - idx = idx.repeat_interleave(self.config.n, dim=0) - attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0) - position_ids = position_ids.repeat_interleave(self.config.n, dim=0) - batch_size = batch_size * self.config.n - seq = torch.cat([idx, response], dim=-1) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) - - # TODO(sgm): fix position_ids on right_pad - # prompt: left pad + response: right pad - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[:, -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - # all the tp ranks should contain the same data here. data in all ranks are valid - batch = TensorDict( - { - "prompts": idx, - "responses": response, - "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - - # free vllm cache engine - if self.config.free_cache_engine: - self.inference_engine.free_cache_engine() - - return DataProto(batch=batch) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 4f8109e3c..988dac407 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections.abc import AsyncGenerator -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import os +import pickle +from typing import Any, Callable, Optional -import cloudpickle import ray +import zmq from omegaconf import DictConfig from starlette.requests import Request from starlette.responses import JSONResponse, StreamingResponse @@ -26,6 +27,8 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.inputs import TokensPrompt +from vllm.outputs import RequestOutput from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.executor.abstract import Executor from vllm.worker.worker_base import WorkerWrapperBase @@ -36,37 +39,102 @@ logger = logging.getLogger(__file__) +def _get_model_runner_workers(vllm_config, init_ray: bool = True): + assert vllm_config.instance_id is not None, "instance_id must be set for external ray actors." + + fields = vllm_config.instance_id.split(":") + assert len(fields) == 4, ( + f"instance_id: {vllm_config.instance_id} must be in the format of " + f":::." + ) + namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) + + # Make sure subprocess in same namespace as parent actor. + # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} + if init_ray: + ray.init(namespace=namespace) + actor_names = [ + actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict") + ] + + vllm_tp_size = vllm_config.parallel_config.tensor_parallel_size + assert len(actor_names) == vllm_dp_size * vllm_tp_size, ( + f"instance_id: {vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: " + f"{vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected." + ) + + def get_pg_index_and_local_rank(actor_name) -> tuple[int, int]: + fields = actor_name.split(":") + assert len(fields) == 2, f"invalid actor name: {actor_name}" + pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) + return pg_index, local_rank + + # sort actor names by pg_index and local_rank + actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) + actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size] + workers: list[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] + print(f"instance_id: {vllm_config.instance_id} initializes with external actors: {actor_names}") + + return workers + + class ExternalRayDistributedExecutor(Executor): """An executor that engines are launched by external ray actors.""" uses_ray: bool = False def _init_executor(self) -> None: - assert self.vllm_config.instance_id is not None, "instance_id must be set for external ray actors." + self.workers = _get_model_runner_workers(vllm_config=self.vllm_config, init_ray=True) + + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=None, + rank=None, + distributed_init_method="env://", + is_driver_worker=True, + ) + self.collective_rpc("init_worker", args=([kwargs],)) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + print(f"instance_id: {self.vllm_config.instance_id} initializes finished.") - fields = self.vllm_config.instance_id.split(":") - assert len(fields) == 4, f"instance_id: {self.vllm_config.instance_id} must be in the format of :::." - namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) + def collective_rpc( + self, + method: str | Callable, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[Any]: + # TODO(wuxibin): support ray compiled graph + if isinstance(method, str): + sent_method = method + else: + sent_method = pickle.dumps(method) + del method - # Make sure subprocess in same namespace as parent actor. - # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} - ray.init(namespace=namespace) - actor_names = [actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict")] + # ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization. + outputs = ray.get( + [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers] + ) + return outputs + + def check_health(self): + return - vllm_tp_size = self.vllm_config.parallel_config.tensor_parallel_size - assert len(actor_names) == vllm_dp_size * vllm_tp_size, f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected." - def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: - fields = actor_name.split(":") - assert len(fields) == 2, f"invalid actor name: {actor_name}" - pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - return pg_index, local_rank +class ExternalZeroMQDistributedExecutor(Executor): + """An executor that engines are launched by external ray actors.""" - # sort actor names by pg_index and local_rank - actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) - actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size] - self.workers: List[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] - print(f"instance_id: {self.vllm_config.instance_id} intializes with external actors: {actor_names}") + uses_ray: bool = False + + def _init_executor(self) -> None: + addresses = os.environ["VERL_VLLM_ZMQ_ADDRESSES"].split(",") + self.context = zmq.Context() + self.sockets = [] + for address in addresses: + socket = self.context.socket(zmq.REQ) + socket.connect(address) + self.sockets.append(socket) kwargs = dict( vllm_config=self.vllm_config, @@ -78,24 +146,27 @@ def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: self.collective_rpc("init_worker", args=([kwargs],)) self.collective_rpc("init_device") self.collective_rpc("load_model") - print(f"instance_id: {self.vllm_config.instance_id} intializes finished.") def collective_rpc( self, - method: Union[str, Callable], + method: str | Callable, timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None, - ) -> List[Any]: - # TODO(wuxibin): support ray compiled graph + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[Any]: if isinstance(method, str): sent_method = method else: - sent_method = cloudpickle.dumps(method) + sent_method = pickle.dumps(method) del method - # ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization. - outputs = ray.get([worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers]) + message = pickle.dumps((sent_method, args, kwargs or {})) + for socket in self.sockets: + socket.send(message, zmq.DONTWAIT) + + outputs = [] + for socket in self.sockets: + outputs.append(pickle.loads(socket.recv())) return outputs def check_health(self): @@ -122,14 +193,14 @@ class AsyncvLLMServer(AsyncServerBase): def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str): """ Args: - config: DictConfig, actor_rollout_ref config. + config: DictConfig. vllm_dp_size: int, vllm data parallel size. vllm_dp_rank: int, vllm data parallel rank. wg_prefix: str, worker group prefix, used to lookup actors. """ super().__init__() - self.config = config + self.config = config.actor_rollout_ref self.vllm_dp_size = vllm_dp_size self.vllm_dp_rank = vllm_dp_rank self.wg_prefix = wg_prefix @@ -147,46 +218,52 @@ async def init_engine(self): tensor_parallel_size = config.get("tensor_model_parallel_size", 1) max_num_batched_tokens = config.get("max_num_batched_tokens", 8192) max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length - max_model_len = int(max_model_len) + self.max_model_len = int(max_model_len) # Override default generation config from hugging face model config, # user can still override them by passing kwargs in each request. kwargs = dict( n=1, logprobs=0, - max_tokens=config.response_length, + repetition_penalty=1.0, + max_new_tokens=config.response_length, ) for k in config.keys(): if hasattr(SamplingParams(), str(k)): kwargs[k] = config.get(k) print(f"override_generation_config: {kwargs}") + backend = os.environ.get("VERL_VLLM_DISTRIBUTED_BACKEND", "zeromq") + if backend == "zeromq": + distributed_executor_backend = ExternalZeroMQDistributedExecutor + elif backend == "ray": + distributed_executor_backend = ExternalRayDistributedExecutor + else: + distributed_executor_backend = None + engine_args = AsyncEngineArgs( model=local_path, - enable_sleep_mode=True, + enable_sleep_mode=config.free_cache_engine, override_generation_config=kwargs, tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=ExternalRayDistributedExecutor, + distributed_executor_backend=distributed_executor_backend, dtype=config.dtype, enforce_eager=config.enforce_eager, gpu_memory_utilization=config.gpu_memory_utilization, disable_custom_all_reduce=True, - disable_mm_preprocessor_cache=True, skip_tokenizer_init=False, - max_model_len=max_model_len, + max_model_len=self.max_model_len, load_format="auto", disable_log_stats=config.disable_log_stats, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=config.enable_chunked_prefill, enable_prefix_caching=True, trust_remote_code=trust_remote_code, - seed=self.vllm_dp_rank, + seed=config.get("seed", 0), ) # init async llm engine - vllm_config = engine_args.create_engine_config() - namespace = ray.get_runtime_context().namespace - vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" + vllm_config = self._create_engine_config(engine_args) self.engine = AsyncLLM.from_vllm_config(vllm_config) # build serving chat @@ -201,8 +278,24 @@ async def init_engine(self): request_logger=RequestLogger(max_log_len=4096), chat_template=None, chat_template_content_format="auto", + enable_auto_tools=config.multi_turn.tool_config_path is not None, + tool_parser=config.multi_turn.format, # hermes, llama3_json, ... ) + def _create_engine_config(self, engine_args: AsyncEngineArgs): + vllm_config = engine_args.create_engine_config() + namespace = ray.get_runtime_context().namespace + vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" + + # VERL_VLLM_ZMQ_ADDRESSES + if engine_args.distributed_executor_backend == ExternalZeroMQDistributedExecutor: + workers = _get_model_runner_workers(vllm_config=vllm_config, init_ray=False) + zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in workers]) + print(f"VERL_VLLM_ZMQ_ADDRESSES: {zmq_addresses}") + os.environ["VERL_VLLM_ZMQ_ADDRESSES"] = ",".join(zmq_addresses) + + return vllm_config + async def chat_completion(self, raw_request: Request): """OpenAI-compatible HTTP endpoint. @@ -220,32 +313,26 @@ async def chat_completion(self, raw_request: Request): assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) - async def chat_completion_generator(self, request: ChatCompletionRequest) -> AsyncGenerator[Tuple[int, str]]: - """Direct chat completion without FastAPI. + async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + max_tokens = self.max_model_len - len(prompt_ids) + sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) + prompt = TokensPrompt(prompt_token_ids=prompt_ids) + generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) - Args: - request: ChatCompletionRequest, request object. - - Returns: - AsyncGenerator[Tuple[int, str]]: async generator of (status_code, data) pairs. - """ - generator = await self.openai_serving_chat.create_chat_completion(request) - if isinstance(generator, ErrorResponse): - data = generator.model_dump_json(exclude_unset=True) - yield generator.code, f"data: {data}\n\n" + # Get final response + final_res: Optional[RequestOutput] = None + async for output in generator: + final_res = output + assert final_res is not None - if request.stream: - async for chunk in generator: - yield 200, chunk - else: - assert isinstance(generator, ChatCompletionResponse) - data = generator.model_dump_json(exclude_unset=True) - yield 200, f"data: {data}\n\n" + return final_res.outputs[0].token_ids async def wake_up(self): - await self.engine.wake_up() + if self.config.rollout.free_cache_engine: + await self.engine.wake_up() async def sleep(self): # TODO: https://github.com/vllm-project/vllm/issues/17103 await self.engine.reset_prefix_cache() - await self.engine.sleep() + if self.config.rollout.free_cache_engine: + await self.engine.sleep() diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py deleted file mode 100644 index a47d5f632..000000000 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The vllm_rollout that can be applied in different backend -When working with FSDP: -- Use DTensor weight loader (recommended) or HF weight loader -- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM -When working with Megatron: -- Use Megatron weight loader -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - -import logging -import os -from contextlib import contextmanager -from copy import deepcopy -from typing import List - -import torch -import torch.distributed -from omegaconf import DictConfig, OmegaConf -from tensordict import TensorDict -from torch import nn -from vllm import SamplingParams - -from verl import DataProto -from verl.third_party.vllm import LLM, vllm_version -from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.debug import GPUMemoryLogger -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.base import BaseRollout - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -from vllm.lora.request import LoRARequest - -# TODO -# 1. support pp in vllm -# 2. passing tokenizer is not necessary? no encoding/decoding is happending here -# 3. simplify init logics - - -# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]: - # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids - - -class vLLMRollout(BaseRollout): - def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs): - """A vLLM rollout. It requires the module is supported by the vllm. - - Args: - module: module here follows huggingface APIs - config: DictConfig - tokenizer: the task/model tokenizer - model_hf_config: the huggingface config to initiallize the generating model in vllm - **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group - """ - super().__init__() - self.config = config - assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" - - tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size" - max_num_batched_tokens = int(self.config.get("max_num_batched_tokens", 8192)) - - if kwargs.get("train_tp") is not None: - # deployed with megatron - import os - - os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" - os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - train_tp = kwargs.get("train_tp") - num_tp_per_train_tp = train_tp // tensor_parallel_size - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) - - rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) - if not rope_scaling_config: - assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length" - - max_model_len = self.config.max_model_len if self.config.max_model_len else config.prompt_length + config.response_length - max_model_len = int(max_model_len) - - if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: - raise ValueError( - "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ - please increase max_num_batched_tokens or disable chunked prefill" - ) - - # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) - # For each vLLM engine parameter, - # - `None` means not setting it, so we pop it, and leave it to vLLM default value - # (which can vary across different vLLM versions); - # - Otherwise it's the desired value we want to explicitly set. - engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} - lora_kwargs = kwargs.pop('lora_kwargs', {}) - self.lora_kwargs = lora_kwargs - self.inference_engine = LLM( - actor_module, - tokenizer=tokenizer, - model_hf_config=model_hf_config, - tensor_parallel_size=tensor_parallel_size, - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - skip_tokenizer_init=False, - max_model_len=max_model_len, - load_format=config.load_format, - disable_log_stats=config.disable_log_stats, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=config.enable_chunked_prefill, - **lora_kwargs, - **engine_kwargs, - ) - - # Offload vllm model to reduce peak memory usage - self.inference_engine.offload_model_weights() - - kwargs = dict( - n=1, - logprobs=0, # can be set to 0 and let actor to recompute - max_tokens=config.response_length, - ) - - # we may detokenize the result all together later - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - kwargs["detokenize"] = False - - # supporting adding any sampling params from the config file - for k in config.keys(): - if hasattr(SamplingParams(), str(k)): - kwargs[k] = config.get(k) - - print(f"kwargs: {kwargs}") - self.sampling_params = SamplingParams(**kwargs) - - self.pad_token_id = tokenizer.pad_token_id - - @contextmanager - def update_sampling_params(self, **kwargs): - # update sampling params - old_sampling_params_args = {} - if kwargs: - for key, value in kwargs.items(): - if hasattr(self.sampling_params, key): - old_value = getattr(self.sampling_params, key) - old_sampling_params_args[key] = old_value - setattr(self.sampling_params, key, value) - yield - # roll back to previous sampling params - # if len(old_sampling_params_args): - for key, value in old_sampling_params_args.items(): - setattr(self.sampling_params, key, value) - - # NOTE: added by Reasoning360. timer for precise logging - @staticmethod - @contextmanager - def timer(): - import time - - start = end = time.perf_counter() - yield lambda: end - start - end = time.perf_counter() - - @GPUMemoryLogger(role="vllm rollout spmd", logger=logger) - @torch.no_grad() - def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - # rebuild vllm cache engine - if self.config.free_cache_engine: - self.inference_engine.init_cache_engine() - - idx = prompts.batch["input_ids"] # (bs, prompt_length) - # left-padded attention_mask - attention_mask = prompts.batch["attention_mask"] - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - - idx_list = [] - # parse idx from torch.Tensor to List[List[str]] - for i in range(batch_size): - idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) - - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - if not do_sample: - kwargs = { - "best_of": 1, - "top_p": 1.0, - "top_k": -1, - "min_p": 0.0, - "temperature": 0, - "n": 1, # if greedy, only 1 response - } - elif is_validate: - # TODO: try ** - kwargs = { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } - - # NOTE: added by Reasoning360, modify sampling params for batches too small - if "num_samples" in prompts.meta_info: - kwargs["n"] = prompts.meta_info["num_samples"] - - lora_requests = None - if self.lora_kwargs: - # self.inference_engine.llm_engine.list_loras - lora_int_ids = list(self.inference_engine.llm_engine.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id=lora_int_ids[0] - lora_requests = [LoRARequest(lora_name=f"{lora_int_id}",lora_int_id=lora_int_id,lora_path="/simon-stub-path")] * batch_size - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs), self.timer() as t: - output = self.inference_engine.generate( - prompts=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - prompt_token_ids=idx_list, - lora_request=lora_requests, - use_tqdm=False, - ) - - # TODO(sgm): disable logprob when recompute_log_prob is enable - # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) - response = output[0].to(idx.device) - log_probs = output[1].to(idx.device) - - if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) - - # utilize current sampling params - if self.sampling_params.n > 1 and do_sample: - idx = idx.repeat_interleave(self.sampling_params.n, dim=0) - attention_mask = attention_mask.repeat_interleave(self.sampling_params.n, dim=0) - position_ids = position_ids.repeat_interleave(self.sampling_params.n, dim=0) - batch_size = batch_size * self.sampling_params.n - seq = torch.cat([idx, response], dim=-1) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) - - # TODO(sgm): fix position_ids on right_pad - # prompt: left pad + response: right pad - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[:, -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - tokens_per_second = torch.sum(response_attention_mask).item() / t() - import os - - print( - f'Tokens per second: {tokens_per_second} t/s on device {os.environ["CUDA_VISIBLE_DEVICES"]} on host {os.uname().nodename}', - flush=True, - ) - - # all the tp ranks should contain the same data here. data in all ranks are valid - batch = TensorDict( - { - "prompts": idx, - "responses": response, - "input_ids": seq, # here input_ids become the whole sentences - 'rollout_log_probs': log_probs, # we will recompute old log prob with actor - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - - # free vllm cache engine - if self.config.free_cache_engine: - self.inference_engine.free_cache_engine() - - # NOTE: added by Reasoning360 - # metrics = self.report_memory_usage(reset=True) - # # NOTE: we do not use meta_info because dp collect fn only picks - # # meta_info of the first data. - # non_tensor_batch = { - # 'metrics_' + k: np.asarray([v] * seq.size(0), dtype=object) for k, v in metrics.items() - # } or None - - return DataProto(batch=batch) - - def report_memory_usage(self, reset: bool=False): - # NOTE: added by Reasoning360 - method = getattr(self.inference_engine.llm_engine, 'report_page_usage_history', None) - if method is not None: - return method(reset=reset) - return {} diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index b4ea8fb3f..92006fe1a 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -28,23 +28,31 @@ import logging import os +import pickle +import socket +import threading from contextlib import contextmanager from copy import deepcopy -from typing import Any, Dict, List, Union +from types import MethodType +from typing import Any import numpy as np +import ray import torch import torch.distributed +import zmq +from typing import Union, List, Any +from filelock import FileLock from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict from vllm import LLM, SamplingParams from vllm.distributed import parallel_state as vllm_ps from vllm.lora.request import LoRARequest +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.worker.worker_base import WorkerWrapperBase from verl import DataProto -from verl.third_party.vllm import vllm_version -from verl.utils.debug import GPUMemoryLogger +from verl.utils.profiler import GPUMemoryLogger from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout @@ -66,14 +74,12 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids - def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: if isinstance(value, torch.Tensor): return value.repeat_interleave(repeats, dim=0) else: return np.repeat(value, repeats, axis=0) - class vLLMRollout(BaseRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. @@ -87,10 +93,11 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf """ super().__init__() self.config = config - assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size" + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( + "tensor parallel size should be less than or equal to the world size" + ) max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) if kwargs.get("train_tp") is not None: @@ -98,31 +105,43 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf # NOTE: import os removed by Reasoning360. Definitely a bug of the official code. os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - train_tp = kwargs.get("train_tp") - num_tp_per_train_tp = train_tp // tensor_parallel_size - vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) - else: - vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) + vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) if not rope_scaling_config: max_position_embeddings = None if hasattr(model_hf_config, "max_position_embeddings"): max_position_embeddings = model_hf_config.max_position_embeddings - elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"): + elif hasattr(model_hf_config, "llm_config") and hasattr( + model_hf_config.llm_config, "max_position_embeddings" + ): max_position_embeddings = model_hf_config.llm_config.max_position_embeddings - elif hasattr(model_hf_config, "text_config") and hasattr(model_hf_config.text_config, "max_position_embeddings"): + elif hasattr(model_hf_config, "text_config") and hasattr( + model_hf_config.text_config, "max_position_embeddings" + ): max_position_embeddings = model_hf_config.text_config.max_position_embeddings if max_position_embeddings is None: raise ValueError("max_position_embeddings not found in model_hf_config") - - assert max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length" + assert max_position_embeddings >= config.prompt_length + config.response_length, ( + "model context length should be greater than total sequence length" + ) + else: + # handle type where there's a length extend factor + # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support + # for using yarn as an example + rope_scaling_factor = rope_scaling_config.get("factor", 1.0) + + assert ( + model_hf_config.max_position_embeddings * rope_scaling_factor + >= config.prompt_length + config.response_length + ), ( + "model context length should be greater than total sequence length, " + + f"got rope_scaling_factor={rope_scaling_factor} and " + + f"max_position_embeddings={model_hf_config.max_position_embeddings}" + ) max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) + #max_model_len = 1024 * (32 + 4) if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: raise ValueError( @@ -133,30 +152,31 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf trust_remote_code = kwargs.get("trust_remote_code", False) load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format - limit_mm_per_prompt = None - if config.get("limit_images", None): # support for multi-image data - limit_mm_per_prompt = {"image": config.get("limit_images")} - - lora_kwargs = kwargs.pop('lora_kwargs', {}) + lora_kwargs = kwargs.pop("lora_kwargs", {}) self.lora_kwargs = lora_kwargs # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) + engine_kwargs = ( + {} + if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs + else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) + ) # For each vLLM engine parameter, # - `None` means not setting it, so we pop it, and leave it to vLLM default value # (which can vary across different vLLM versions); # - Otherwise it's the desired value we want to explicitly set. engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} + if config.get("limit_images", None): # support for multi-image data + engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")} + self.inference_engine = LLM( model=model_path, - enable_sleep_mode=True, + enable_sleep_mode=config.free_cache_engine, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend="external_launcher", dtype=config.dtype, enforce_eager=config.enforce_eager, gpu_memory_utilization=config.gpu_memory_utilization, disable_custom_all_reduce=True, - disable_mm_preprocessor_cache=True, - limit_mm_per_prompt=limit_mm_per_prompt, skip_tokenizer_init=False, max_model_len=max_model_len, load_format=load_format, @@ -165,7 +185,8 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf enable_chunked_prefill=config.enable_chunked_prefill, enable_prefix_caching=True, trust_remote_code=trust_remote_code, - seed=int(os.getenv("RANK", "0")) // tensor_parallel_size, # NOTE: modified by Reasoning360. Originally config.get("seed", 0) + seed=int(os.getenv("RANK", "0")) + // tensor_parallel_size, # NOTE: modified by Reasoning360. Originally config.get("seed", 0) **lora_kwargs, **engine_kwargs, ) @@ -173,7 +194,8 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf # self._monkey_patch_vllm_engine_v0() # Offload vllm model to reduce peak memory usage - self.inference_engine.sleep(level=1) + if config.free_cache_engine: + self.inference_engine.sleep(level=1) kwargs = dict( n=1, @@ -181,16 +203,19 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf max_tokens=config.response_length, ) - # # we may detokenize the result all together later - if vllm_version != "0.3.1": - kwargs["detokenize"] = False + kwargs["detokenize"] = False # supporting adding any sampling params from the config file for k in config.keys(): if hasattr(SamplingParams(), str(k)): kwargs[k] = config.get(k) + kwargs['n'] = 1 + print(f"kwargs by 360 (SPMD): {kwargs}") + + # Preserve the n value for deterministic expansion (needed for both paths) + n_samples = kwargs.get('n', 1) - print(f"kwargs: {kwargs}") + # users can customize different sampling_params at different run self.sampling_params = SamplingParams(**kwargs) self.pad_token_id = tokenizer.pad_token_id @@ -224,17 +249,26 @@ def timer(): @GPUMemoryLogger(role="vllm rollout spmd", logger=logger) @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - # rebuild vllm cache engine - if ( - vllm_version - in ( - "0.5.4", - "0.6.3", - ) - and self.config.free_cache_engine - ): - self.inference_engine.init_cache_engine() + """Generate sequences for a batch of prompts. + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch. + - prompts: [bsz, prompt_length], prompt token ids from dataset. + - responses: [bsz, response_length], output token ids include response tokens + from LLM generation and observation tokens from tool_calls. + - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. + - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens + and response tokens. + - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. + - position_ids: [bsz, prompt_length + response_length], incremental position ids. + + For multi-turn conversations: + responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| + response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| + """ idx = prompts.batch["input_ids"] # (bs, prompt_length) # left-padded attention_mask attention_mask = prompts.batch["attention_mask"] @@ -247,17 +281,23 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: non_tensor_batch = prompts.non_tensor_batch if "raw_prompt_ids" not in non_tensor_batch: - non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object + ) if batch_size != len(non_tensor_batch["raw_prompt_ids"]): raise RuntimeError("vllm sharding manager is not work properly.") if "multi_modal_data" in non_tensor_batch: vllm_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")): + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data"), strict=True + ): vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data}) else: - vllm_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] + vllm_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] # ensure the type of `prompt_token_ids` passed to vllm is list[int] # https://github.com/volcengine/verl/pull/772 @@ -265,7 +305,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: if isinstance(input_data["prompt_token_ids"], np.ndarray): input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + raise TypeError( + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) do_sample = prompts.meta_info.get("do_sample", True) is_validate = prompts.meta_info.get("validate", False) @@ -297,44 +339,160 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # NOTE: added by Reasoning360 if "num_samples" in prompts.meta_info: kwargs["n"] = prompts.meta_info["num_samples"] - + + # Preserve the n value for deterministic expansion (needed for both paths) + n_samples = kwargs.get('n', 1) + + # Check for individual response lengths in non_tensor_batch + individual_sampling_params = None + per_prompt_length_budget = None + if "per_prompt_length_budget" in prompts.non_tensor_batch: + response_length = prompts.non_tensor_batch["per_prompt_length_budget"] + + # Convert to list if it's a numpy array + if isinstance(response_length, np.ndarray): + response_length = response_length.tolist() + + # Ensure we have the right number for this worker's batch + if len(response_length) != batch_size: + # The framework distributes data across workers, so we may get a subset + # Take the first batch_size elements (assumes ordered distribution) + response_length = response_length[:batch_size] if len(response_length) > batch_size else response_length + print(f"vLLM rollout (SPMD): Using {len(response_length)} response lengths for batch_size {batch_size}") + + # Set flag to use individual context manager approach + individual_sampling_params = True + per_prompt_length_budget = response_length + print(f"vLLM rollout (SPMD): Using individual response lengths (range: {min(response_length)}-{max(response_length)})") + + # Check for single response_length override in meta_info (backward compatibility) + elif "response_length" in prompts.meta_info: + response_length = prompts.meta_info["response_length"] + # meta_info should only contain single values, not lists + kwargs["max_tokens"] = response_length + print(f"vLLM rollout (SPMD): Overriding max_tokens to {response_length}") + # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs), self.timer() as t: - outputs = self.inference_engine.generate( - prompts=vllm_inputs, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - lora_request=lora_requests, - use_tqdm=False, - ) - - # TODO(sgm): disable logprob when recompute_log_prob is enable - # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) - - response = [] - rollout_log_probs = [] - for output in outputs: - for sample_id in range(len(output.outputs)): - response_ids = output.outputs[sample_id].token_ids - response.append(response_ids) - curr_log_prob = [] - for i, logprob in enumerate(output.outputs[sample_id].logprobs): - curr_log_prob.append(logprob[response_ids[i]].logprob) - rollout_log_probs.append(curr_log_prob) - + if is_validate: + kwargs["max_tokens"] = 32768 + with self.update_sampling_params(**kwargs), self.timer() as t: + outputs = self.inference_engine.generate( + prompts=vllm_inputs, # because we have already convert it to prompt token id + sampling_params=self.sampling_params, + lora_request=lora_requests, + use_tqdm=False, + ) + kwargs["max_tokens"] = self.config.response_length + self.update_sampling_params(**kwargs) + + elif individual_sampling_params is not None: + # PALU training + # Use context manager approach to create individual sampling params while maintaining batch efficiency + individual_sampling_params_list = [] + + # Create individual sampling params using context manager approach (no rollback needed since we're not calling generate) + for length in response_length: + # Create kwargs for this specific prompt + individual_kwargs = kwargs.copy() + individual_kwargs["max_tokens"] = int(length) + + # Temporarily create a modified sampling params using the same logic as update_sampling_params + # Save current state + old_sampling_params_args = {} + for key, value in individual_kwargs.items(): + if hasattr(self.sampling_params, key): + old_value = getattr(self.sampling_params, key) + old_sampling_params_args[key] = old_value + setattr(self.sampling_params, key, value) + + # Create a copy with current state (same as what update_sampling_params would use) + import copy + individual_sampling_param = copy.deepcopy(self.sampling_params) + individual_sampling_params_list.append(individual_sampling_param) + + # Restore original state + for key, value in old_sampling_params_args.items(): + setattr(self.sampling_params, key, value) + + # Now do batch inference with all prompts and individual sampling params + with self.timer() as t: + outputs = self.inference_engine.generate( + prompts=vllm_inputs, + sampling_params=individual_sampling_params_list, + lora_request=lora_requests, + use_tqdm=False, + ) + else: + # GRPO baseline training + # Use single sampling params for all prompts (original behavior) + kwargs["max_tokens"] = self.config.response_length + with self.update_sampling_params(**kwargs), self.timer() as t: + outputs = self.inference_engine.generate( + prompts=vllm_inputs, # because we have already convert it to prompt token id + sampling_params=self.sampling_params, + lora_request=lora_requests, + use_tqdm=False, + ) + + # TODO(sgm): disable logprob when recompute_log_prob is enable + # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) + + response = [] + rollout_log_probs = [] + for output in outputs: + for sample_id in range(len(output.outputs)): + response_ids = output.outputs[sample_id].token_ids + response.append(response_ids) + curr_log_prob = [] + for i, logprob in enumerate(output.outputs[sample_id].logprobs): + curr_log_prob.append(logprob[response_ids[i]].logprob) + rollout_log_probs.append(curr_log_prob) + + if is_validate: + response = pad_2d_list_to_length(response, self.pad_token_id, max_length=32768).to(idx.device) + rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=32768).to(idx.device) + else: response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device) rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device) - rollout_log_probs = rollout_log_probs.to(torch.float32) - - if self.sampling_params.n > 1 and do_sample: - idx = _repeat_interleave(idx, self.sampling_params.n) - attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n) - position_ids = _repeat_interleave(position_ids, self.sampling_params.n) - batch_size = batch_size * self.sampling_params.n - # NOTE(linjunrong): for multi-turn https://github.com/volcengine/verl/pull/1037 - if "tools_kwargs" in non_tensor_batch.keys(): - non_tensor_batch["tools_kwargs"] = _repeat_interleave(non_tensor_batch["tools_kwargs"], self.sampling_params.n) - - seq = torch.cat([idx, response], dim=-1) + rollout_log_probs = rollout_log_probs.to(torch.float32) + + # Check if vLLM has already expanded the batch due to n > 1 + response_batch_size = response.size(0) + original_batch_size = batch_size + + # If response tensor is already expanded (response_batch_size > original_batch_size), + # then vLLM has handled the expansion internally + vllm_already_expanded = response_batch_size > original_batch_size + + # Original deterministic approach: utilize sampling params for tensor expansion + # n_samples is the same whether we use individual sampling params or not + if n_samples > 1 and do_sample and not vllm_already_expanded: + # Only expand if vLLM hasn't already done it + idx = _repeat_interleave(idx, n_samples) + attention_mask = _repeat_interleave(attention_mask, n_samples) + position_ids = _repeat_interleave(position_ids, n_samples) + batch_size = batch_size * n_samples + # NOTE(linjunrong): for multi-turn https://github.com/volcengine/verl/pull/1037 + # NOTE: Fix batch size mismatch by expanding ALL non_tensor_batch items + expanded_non_tensor_batch = {} + for key, val in non_tensor_batch.items(): + expanded_non_tensor_batch[key] = _repeat_interleave(val, n_samples) + non_tensor_batch = expanded_non_tensor_batch + elif vllm_already_expanded: + # vLLM has already expanded, so we need to expand the prompt tensors to match + expansion_factor = response_batch_size // original_batch_size + idx = _repeat_interleave(idx, expansion_factor) + attention_mask = _repeat_interleave(attention_mask, expansion_factor) + position_ids = _repeat_interleave(position_ids, expansion_factor) + batch_size = response_batch_size + # NOTE(linjunrong): for multi-turn https://github.com/volcengine/verl/pull/1037 + # NOTE: Fix batch size mismatch by expanding ALL non_tensor_batch items + expanded_non_tensor_batch = {} + for key, val in non_tensor_batch.items(): + expanded_non_tensor_batch[key] = _repeat_interleave(val, expansion_factor) + non_tensor_batch = expanded_non_tensor_batch + + seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) @@ -348,41 +506,75 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[..., -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - # NOTE: added by Reasoning360. temporarily disabled to avoid messy logging - # tokens_per_second = torch.sum(response_attention_mask).item() / t() - # print( - # f'Tokens per second: {tokens_per_second} t/s on device {os.environ["CUDA_VISIBLE_DEVICES"]} on host {os.uname().nodename}', - # flush=True, - # ) - # all the tp ranks should contain the same data here. data in all ranks are valid batch = TensorDict( { "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - 'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=batch_size, ) - + if self.config.calculate_log_probs: + # we will recompute old log prob with actor + batch["rollout_log_probs"] = rollout_log_probs + + # Prepare meta_info for the returned DataProto + meta_info = prompts.meta_info.copy() + + # Set target_max_response_length for backward compatibility + if per_prompt_length_budget is None: + meta_info["target_max_response_length"] = self.config.response_length + + if is_validate: + generated_response_lengths = response_attention_mask.sum(dim=1) # shape: [batch_size] + # logging + print(f"Generated response lengths: {generated_response_lengths}") + # save them to a local dataframe: + import os + import pandas as pd + job_id = os.environ.get("SLURM_JOB_ID", "nojob") # fallback for local runs + + df = pd.DataFrame({ + "generated_response_lengths": generated_response_lengths.cpu().numpy(), + }) + df.to_csv(f"./response_lengths_job{job_id}_step{prompts.meta_info['global_steps']}.csv", index=False, mode='a', header=False) + # LLM360 (removed in latest vllm) # free vllm cache engine - if ( - vllm_version - in ( - "0.5.4", - "0.6.3", - ) - and self.config.free_cache_engine - ): - self.inference_engine.free_cache_engine() + # if ( + # vllm_version + # in ( + # "0.5.4", + # "0.6.3", + # ) + # and self.config.free_cache_engine + # ): + # self.inference_engine.free_cache_engine() + + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info=meta_info) + - return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) +# https://github.com/vllm-project/vllm/issues/13175 +def _monkey_patch_compute_logits(model, vocab_size: int): + original_compute_logits = model.compute_logits + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + logits = original_compute_logits(hidden_states, sampling_metadata) + logits[..., vocab_size:] = float("-inf") + return logits + + model.compute_logits = MethodType(compute_logits, model) class vLLMAsyncRollout: @@ -390,13 +582,58 @@ class vLLMAsyncRollout: which is engine in single worker process. """ - def __init__(self, *args, **kwargs): + def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): + self.tokenizer = tokenizer + # Engine is deferred to be initialized in init_worker + self.config = config self.inference_engine: WorkerWrapperBase = None self.sharding_manager = None self.is_sleep = False + self.address = self._init_zeromq() + + def _init_zeromq(self) -> str: + tensor_parallel_size = self.config.tensor_model_parallel_size - def init_worker(self, all_kwargs: List[Dict[str, Any]]): + # single node: ipc, multi nodes: tcp + local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"]) + socket_type = "ipc" if tensor_parallel_size <= local_world_size else "tcp" + + # File lock to prevent multiple workers listen to same port + with FileLock("/tmp/verl_vllm_zmq.lock"): + if socket_type == "ipc": + pid = os.getpid() + address = f"ipc:///tmp/verl_vllm_zmq_{pid}.ipc" + else: + ip, port = self._get_free_port() + address = f"tcp://{ip}:{port}" + context = zmq.Context() + self.socket = context.socket(zmq.REP) + self.socket.bind(address) + + self.loop_thread = threading.Thread(target=self._loop_forever) + self.loop_thread.start() + + return address + + def _get_free_port(self): + ip = ray.util.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + port = sock.getsockname()[1] + return ip, port + + def _loop_forever(self): + while True: + message = self.socket.recv() + method, args, kwargs = pickle.loads(message) + result = self.execute_method(method, *args, **kwargs) + self.socket.send(pickle.dumps(result)) + + def get_zeromq_address(self): + return self.address + + def init_worker(self, all_kwargs: list[dict[str, Any]]): """Initialize worker engine.""" all_kwargs[0]["rank"] = int(os.environ["RANK"]) all_kwargs[0]["local_rank"] = 0 @@ -408,10 +645,12 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]): def load_model(self, *args, **kwargs): self.inference_engine.load_model(*args, **kwargs) - # inference engine is intialized now, update sharding manager + # inference engine is initialized now, update sharding manager self.sharding_manager.inference_engine = self.inference_engine self.sharding_manager.model_runner = self.inference_engine.worker.model_runner + _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer)) + def sleep(self, *args, **kwargs): """Offload model weights and discard kv cache.""" if self.is_sleep: @@ -426,7 +665,7 @@ def wake_up(self, *args, **kwargs): self.sharding_manager.__enter__() # pylint: disable=C2801 self.is_sleep = False - def execute_method(self, method: Union[str, bytes], *args, **kwargs): + def execute_method(self, method: str | bytes, *args, **kwargs): if method == "init_worker": return self.init_worker(*args, **kwargs) elif method == "load_model": diff --git a/verl/workers/sharding_manager/base.py b/verl/workers/sharding_manager/base.py index d7415892a..59537be64 100644 --- a/verl/workers/sharding_manager/base.py +++ b/verl/workers/sharding_manager/base.py @@ -19,6 +19,9 @@ class BaseShardingManager: + def __init__(self): + self.timing = {} + def __enter__(self): pass diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index 3608d932b..80201dc56 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import os @@ -29,9 +30,12 @@ from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.device import get_device_id, get_torch_device from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu -from verl.utils.torch_functional import check_cuda_is_available +from verl.utils.model import convert_weight_keys +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from verl.utils.torch_functional import check_device_is_available +from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets from .base import BaseShardingManager @@ -47,26 +51,32 @@ def _preprocess_tensor_for_update_weights(tensor: torch.Tensor): class FSDPSGLangShardingManager(BaseShardingManager): - @check_cuda_is_available() + @check_device_is_available() def __init__( self, module: FSDP, inference_engine: Engine, model_config, + rollout_config, full_params: bool = False, device_mesh: DeviceMesh = None, offload_param: bool = False, + multi_stage_wake_up: bool = False, ): self.module = module self.inference_engine = inference_engine self.model_config = model_config + self.rollout_config = rollout_config self.device_mesh = device_mesh self.offload_param = offload_param + self.multi_stage_wake_up = multi_stage_wake_up # Full params self.full_params = full_params if full_params and fsdp_version(self.module) == 1: - FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()) + FSDP.set_state_dict_type( + self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + ) elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( self.module, @@ -78,93 +88,157 @@ def __init__( self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) def __enter__(self): - torch.cuda.empty_cache() - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_fsdp_model_to_gpu(self.module) - params = self.module.state_dict() - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) - device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy - params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()} - # Copy, not share memory - self.update_weights(params) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) - - del params - if self.offload_param: - offload_fsdp_model_to_cpu(self.module) - torch.cuda.empty_cache() - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.timing = {} + with simple_timer("reshard", self.timing): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.wake_up()) @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) - - self.module.train() - - # add empty cache after each compute - torch.cuda.empty_cache() - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) - - def update_weights(self, params): - if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.resume_memory_occupation() + loop = asyncio.get_event_loop() + loop.run_until_complete(self.sleep()) + async def update_weights(self, params): # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update named_tensors = [(k, v) for k, v in params.items()] load_format = None - for tensor_index, (name, tensor) in enumerate(named_tensors): - serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor)) + # convert megabytes to bytes + update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 + for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): + # On each rank, serialize a batch of (name, tensor) tuples. + # named_tensors_batch will be a list like: + # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...] + named_tensors_batch = [ + (name, MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))) + for name, tensor in batch + ] if self.device_mesh["infer_tp"].get_local_rank() == 0: - gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])] + # On rank 0, prepare a list to hold the gathered batches from all ranks. + gathered_serialized_batches = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])] else: - gathered_serialized_tensors = None + gathered_serialized_batches = None + + # Gather the named_tensors_batch from all ranks to rank 0. + # After this, on rank 0, gathered_serialized_batches will be a list of lists: + # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0 + # [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1 + # ... ] + # On other ranks, gathered_serialized_batches will be None. dist.gather_object( - obj=serialized_tensor, - object_gather_list=gathered_serialized_tensors, + obj=named_tensors_batch, + object_gather_list=gathered_serialized_batches, dst=self.device_mesh["infer_tp"].mesh.tolist()[0], group=self.device_mesh["infer_tp"].get_group(), ) if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.update_weights_from_tensor( + # Use zip(*) to "transpose" the data structure. + # This groups the serialized parts for each individual tensor across all TP ranks. + # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]] + # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ] + logical_tensors = zip(*gathered_serialized_batches, strict=True) + + await self.inference_engine.update_weights_from_tensor( named_tensors=[ + # 'tensor_group' represents a single logical tensor's data from all ranks. ( - name, - LocalSerializedTensor(values=gathered_serialized_tensors), + tensor_group[0][0], # Get the name from the first rank's data. + LocalSerializedTensor( + # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank. + values=[rank_part[1] for rank_part in tensor_group] + ), ) + for tensor_group in logical_tensors + # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) ) ], load_format=load_format, - flush_cache=tensor_index == len(named_tensors) - 1, + flush_cache=False, ) - def release_memory(self): if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.release_memory_occupation() + await self.inference_engine.flush_cache() + + async def release_memory(self): + if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + await self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) + async def wake_up(self): + get_torch_device().empty_cache() + + if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + if self.multi_stage_wake_up: + await self.inference_engine.resume_memory_occupation(tags=["weights"]) + log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger) + else: + await self.inference_engine.resume_memory_occupation() + log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger) + + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) + params = self.module.state_dict() + log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + device = get_device_id() # used when fsdp2 set cpu_offload_policy + params = { + k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items() + } + + # convert weight keys to match the model config + params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) + + # Copy, not share memory + await self.update_weights(params) + log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) + + del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) + get_torch_device().empty_cache() + log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) + + if ( + self.multi_stage_wake_up + and self.rollout_config.free_cache_engine + and self.device_mesh["infer_tp"].get_local_rank() == 0 + ): + await self.inference_engine.resume_memory_occupation(tags=["kv_cache"]) + log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger) + + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) + + @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) + async def sleep(self): + if self.rollout_config.free_cache_engine: + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + await self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + + self.module.train() + + # add empty cache after each compute + get_torch_device().empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) def preprocess_data(self, data: DataProto) -> DataProto: """All gather across tp group to make each rank has identical input.""" diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index ec589933f..1a9677df5 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -18,7 +18,6 @@ import time from collections import OrderedDict -import torch from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP @@ -33,12 +32,18 @@ from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM, vllm_version +from verl.third_party.vllm import LLM from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.device import get_torch_device -from verl.utils.fsdp_utils import fsdp_version, layered_summon_lora_params, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu -from verl.utils.torch_functional import check_cuda_is_available +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.fsdp_utils import ( + fsdp_version, + layered_summon_lora_params, + load_fsdp_model_to_gpu, + offload_fsdp_model_to_cpu, +) +from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from verl.utils.torch_functional import check_device_is_available from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader from .base import BaseShardingManager @@ -48,21 +53,40 @@ class FSDPVLLMShardingManager(BaseShardingManager): - @check_cuda_is_available() - def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_params: bool = False, device_mesh: DeviceMesh = None, offload_param: bool = False, load_format: str = "dummy_hf", layered_summon: bool = True): + """Sharding manager for FSDP models with vLLM inference engine integration. + + Manages parameter synchronization between FSDP training models and vLLM + inference engines, handling both full parameters and LoRA adapters with + efficient memory management and device placement. + """ + + @check_device_is_available() + def __init__( + self, + module: FSDP, + inference_engine: LLM, + model_config, + rollout_config, + full_params: bool = False, + device_mesh: DeviceMesh = None, + offload_param: bool = False, + load_format: str = "dummy_hf", + layered_summon: bool = True, + ): self.module = module - # For AsyncLLM, inference_engine and model_runner are defer intialized in vLLMAsyncRollout.load_model + # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model self.inference_engine = inference_engine - # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None + # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if + # inference_engine else None - if "vllm_v_0_6_3" in str(type(self.inference_engine)) or "vllm_v_0_5_4" in str(type(self.inference_engine)): - # vLLM <= v0.6.3 - self.model_runner = self.inference_engine.llm_engine.model_executor.worker.model_runner if self.inference_engine else None - else: - # vLLM > v0.6.3 - self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if self.inference_engine else None + self.model_runner = ( + self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner + if self.inference_engine + else None + ) self.model_config = model_config + self.rollout_config = rollout_config self.device_mesh = device_mesh self.offload_param = offload_param self.load_format = load_format @@ -71,7 +95,9 @@ def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_param # Full params self.full_params = full_params if full_params and fsdp_version(self.module) == 1: - FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()) + FSDP.set_state_dict_type( + self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + ) elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( self.module, @@ -111,30 +137,42 @@ def __collect_lora_params() -> OrderedDict: if fsdp_version(self.module) > 0: if self.layered_summon: if not self.base_sync_done: - raise ValueError("To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let rollout.load_format=safetensors") + raise ValueError( + "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let " + "rollout.load_format=safetensors" + ) lora_params = layered_summon_lora_params(self.module) else: with FSDP.summon_full_params(self.module, writeback=False): if self.base_sync_done: lora_params = get_peft_model_state_dict(peft_model) - lora_params = {name: param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() for name, param in lora_params.items()} + lora_params = { + name: param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + for name, param in lora_params.items() + } else: model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in next(model.parameters()).device else "cuda" + orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): continue name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") - lora_params[name] = param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") else param.detach().cpu() + lora_params[name] = ( + param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + ) model = model.to(orig_dev) - torch.cuda.empty_cache() + get_torch_device().empty_cache() else: if self.base_sync_done: lora_params = get_peft_model_state_dict(peft_model) else: model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in next(model.parameters()).device else "cuda" + orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): @@ -151,36 +189,29 @@ def __collect_lora_params() -> OrderedDict: # # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - get_torch_device().empty_cache() - - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_fsdp_model_to_gpu(self.module) - - peft_config = None - peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module) - if hasattr(peft_model, "peft_config"): - peft_config = peft_model.peft_config.get("default", None) - params = __collect_lora_params() - else: - params = self.module.state_dict() - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + self.timing = {} + with simple_timer("reshard", self.timing): + get_torch_device().empty_cache() - # Copy, not share memory - load_format = "hf" if self.full_params else "dtensor" + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - self.inference_engine.sync_model_weights(params, load_format=load_format) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) - del params - else: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - self.inference_engine.wake_up(tags=["weights"]) + peft_config = None + peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module) + if hasattr(peft_model, "peft_config"): + peft_config = peft_model.peft_config.get("default", None) + params = __collect_lora_params() else: - self.inference_engine.wake_up() + params = self.module.state_dict() + params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) + log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + + if self.rollout_config.free_cache_engine: + if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: + self.inference_engine.wake_up(tags=["weights"]) + else: + self.inference_engine.wake_up() # update model params self.update_params(params, peft_config=peft_config) @@ -190,25 +221,22 @@ def __collect_lora_params() -> OrderedDict: offload_fsdp_model_to_cpu(self.module) get_torch_device().empty_cache() - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: + if ( + self.rollout_config.free_cache_engine + and "tags" in inspect.signature(self.inference_engine.wake_up).parameters + ): self.inference_engine.wake_up(tags=["kv_cache"]) - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) + log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): - # TODO(ZSL): check this - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - self.inference_engine.offload_model_weights() - else: + if self.rollout_config.free_cache_engine: self.inference_engine.sleep(level=1) self.module.train() @@ -228,13 +256,7 @@ def preprocess_data(self, data: DataProto) -> DataProto: return data # TODO: Current impl doesn't consider FSDP with torch micro-dp - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - group = vllm_ps.get_tensor_model_parallel_group() - else: - group = vllm_ps.get_tensor_model_parallel_group().device_group + group = vllm_ps.get_tensor_model_parallel_group().device_group all_gather_data_proto(data=data, process_group=group) return data @@ -248,6 +270,16 @@ def postprocess_data(self, data: DataProto) -> DataProto: return data.chunk(chunks=self.tp_size)[self.tp_rank] def update_params(self, updated_params, peft_config=None): + """Update model parameters in the vLLM inference engine. + + Synchronizes parameters from the FSDP training model to the vLLM inference + engine, handling both full model parameters and LoRA adapters with proper + device placement and memory management. + + Args: + updated_params (dict): Dictionary of parameter names to tensor values. + peft_config (optional): PEFT configuration for LoRA adapters. + """ model = self.model_runner.model if peft_config: if self.base_sync_done: @@ -265,18 +297,46 @@ def update_params(self, updated_params, peft_config=None): else: def replace_lora_wrapper(k): + """Replace LoRA parameter keys with base layer equivalents. + + Transforms LoRA parameter names to their corresponding base layer + names for proper weight loading in vLLM when base model sync is not done. + + Args: + k (str): Original parameter key name. + + Returns: + str: Transformed parameter key for base layer. + """ stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] - if any([k.endswith(f"{s}.weight") for s in stacked_params]): - return k.replace(".weight", ".base_layer.weight") - if any([k.endswith(f"{s}.bias") for s in stacked_params]): - return k.replace(".bias", ".base_layer.bias") + if k.endswith(".weight"): + module_k = k[: -len(".weight")] + if check_exclude_modules(peft_config, module_k): + return k + elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( + peft_config, module_k + ): + return f"{module_k}.base_layer.weight" + if k.endswith(".bias"): + module_k = k[: -len(".bias")] + if check_exclude_modules(peft_config, module_k): + return k + elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( + peft_config, module_k + ): + return f"{module_k}.base_layer.bias" return k updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()} patch_vllm_moe_model_weight_loader(model) - device = get_torch_device().current_device() # used when fsdp2 set cpu_offload_policy - loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items())) + device = get_device_id() # used when fsdp2 set cpu_offload_policy + loaded_params = model.load_weights( + ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in updated_params.items() + ) + ) self.base_sync_done = True logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}") diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index 4e047e212..d353c70e8 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py @@ -17,17 +17,27 @@ This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. """ +import asyncio import logging import os -import torch +import torch.distributed as dist +from omegaconf import DictConfig from sglang.srt.entrypoints.engine import Engine +from sglang.srt.model_executor.model_runner import LocalSerializedTensor +from sglang.srt.utils import MultiprocessingSerializer from torch import nn from torch.distributed.device_mesh import DeviceMesh from verl.protocol import DataProto, all_gather_data_proto -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.megatron_utils import per_tensor_generator +from verl.utils.device import get_torch_device +from verl.utils.megatron_utils import ( + load_megatron_model_to_gpu, + offload_megatron_model_to_cpu, + per_tensor_generator, +) +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets from .base import BaseShardingManager @@ -38,7 +48,8 @@ """ Megatron Hybrid Engine: - During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) +- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all + the parameters) - Bind the parameters to the inference engine - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. @@ -46,23 +57,50 @@ class MegatronSGLangShardingManager(BaseShardingManager): + """A sharding manager for Megatron-style training & inference with SGLang. + + This class manages the sharding of model parameters between training and inference + phases in a Megatron-style parallel setup. It handles: + - Loading/offloading parameters between CPU/GPU + - Updating inference engine weights + - Managing random states for reproducibility + - Data preprocessing for distributed inference + + Args: + actor_module (nn.ModuleList): The actor model modules + inference_engine (Engine): The SGLang inference engine + model_config: Configuration for the actor's model + rollout_config: Configuration for rollout generation + transformer_config: Transformer-specific configuration + layer_name_mapping: Mapping between layer names and parameters + weight_converter: Utility for converting weights between formats + device_mesh (DeviceMesh | None): PyTorch device mesh for distributed training + offload_param (bool): Whether to offload parameters to CPU when not in use + """ + def __init__( self, actor_module: nn.ModuleList, inference_engine: Engine, - model_config, + model_config: DictConfig, + rollout_config: DictConfig, transformer_config, layer_name_mapping, weight_converter, device_mesh: DeviceMesh | None = None, + offload_param: bool = False, + bridge=None, ): self.actor_module = actor_module self.inference_engine = inference_engine self.model_config = model_config + self.rollout_config = rollout_config self.transformer_config = transformer_config self.layer_name_mapping = layer_name_mapping self.weight_converter = weight_converter self.device_mesh = device_mesh + self.bridge = bridge + self.offload_param = offload_param if self.device_mesh is not None: self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0] @@ -70,75 +108,143 @@ def __init__( self.infer_tp_size = self.inference_engine._tp_size # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) def __enter__(self): - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - self.update_weights(per_tensor_param) - - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.timing = {} + with simple_timer("reshard", self.timing): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.wake_up()) @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) def __exit__(self, exc_type, exc_value, traceback): - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) - - for model in self.actor_module: - model.train() - # add empty cache after each compute - torch.cuda.empty_cache() + loop = asyncio.get_event_loop() + loop.run_until_complete(self.sleep()) - # restore random states - if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) - - def update_weights(self, params): - if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.resume_memory_occupation() + async def update_weights(self, params): + """ + Update model weights using tensor buckets, similar to THUDM/slime's implementation. - # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update - # named_tensors = [(k, v) for k, v in params.items()] + Notes: + - For the best performance of `rebuild_cuda_tensor`, it is recommended to: + 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`. + 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` + when using Tensor Parallelism (TP >= 8). + - See reference implementations in SLIME: + - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452 + - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 + """ + if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + await self.inference_engine.resume_memory_occupation() named_tensors = params load_format = None - for tensor_index, (name, tensor) in enumerate(named_tensors): + + update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 + for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): + # On each rank, serialize a batch of (name, tensor) tuples. + # named_tensors_batch will be a list like: + # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...] + named_tensors_batch = [ + (name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch + ] + + if self.device_mesh["tp"].get_local_rank() == 0: + # On rank 0, prepare a list to hold the gathered batches from all ranks. + gathered_serialized_batches = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])] + else: + gathered_serialized_batches = None + + # Gather the named_tensors_batch from all ranks to rank 0. + # After this, on rank 0, gathered_serialized_batches will be a list of lists: + # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0 + # [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1 + # ... ] + # On other ranks, gathered_serialized_batches will be None. + dist.gather_object( + obj=named_tensors_batch, + object_gather_list=gathered_serialized_batches, + dst=self.device_mesh["tp"].mesh.tolist()[0], + group=self.device_mesh["tp"].get_group(), + ) + if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.update_weights_from_tensor( + # Use zip(*) to "transpose" the data structure. + # This groups the serialized parts for each individual tensor across all TP ranks. + # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]] + # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ] + logical_tensors = zip(*gathered_serialized_batches, strict=False) + await self.inference_engine.update_weights_from_tensor( named_tensors=[ + # 'tensor_group' represents a single logical tensor's data from all ranks. ( - name, - tensor.detach(), + tensor_group[0][0], # Get the name from the first rank's data. + LocalSerializedTensor( + # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank. + values=[rank_part[1] for rank_part in tensor_group] + ), ) + for tensor_group in logical_tensors + # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) ) ], load_format=load_format, flush_cache=False, ) - if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.flush_cache() - - def release_memory(self): if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.release_memory_occupation() + await self.inference_engine.flush_cache() + + async def release_memory(self): + if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + await self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) + async def wake_up(self): + if self.offload_param: + load_megatron_model_to_gpu(self.actor_module) + if self.bridge is not None: + per_tensor_param = self.bridge.export_weights(self.actor_module) + else: + per_tensor_param = per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) + await self.update_weights(per_tensor_param) + if self.offload_param: + offload_megatron_model_to_cpu(self.actor_module) + get_torch_device().empty_cache() + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) + + @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) + async def sleep(self): + if self.rollout_config.free_cache_engine: + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + await self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + + for model in self.actor_module: + model.train() + # add empty cache after each compute + get_torch_device().empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index 5179a6137..b04352c24 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -21,30 +21,20 @@ import torch import torch.distributed -import torch.distributed as dist -from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import parallel_state as mpu -from megatron.core.transformer.module import Float16Module +from omegaconf import DictConfig from torch import nn -from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from verl import DataProto from verl.models.mcore.weight_converter import McoreToHFWeightConverterBase from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM, vllm_version +from verl.third_party.vllm import LLM from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.debug import GPUMemoryLogger -from verl.utils.megatron_utils import ( - get_model, - per_tensor_generator, - unwrap_model, -) -from verl.utils.memory_buffer import ( - build_memory_buffer, - build_memory_reference_from_module, - get_weight_buffer_meta_from_module, -) -from verl.utils.torch_functional import check_cuda_is_available +from verl.utils.device import get_torch_device +from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.profiler.performance import simple_timer +from verl.utils.torch_functional import check_device_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader from .base import BaseShardingManager @@ -53,198 +43,6 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -class AllGatherPPModel: - def __init__(self, model_provider, use_distributed_optimizer=True) -> None: - print( - "[WARNING] This class is deprecated and will no longer be supported. \ -Consider using the `MegatronPPOActor` class directly as a replacement." - ) - self._pp_group = mpu.get_pipeline_model_parallel_group() - self._pp_rank = mpu.get_pipeline_model_parallel_rank() - self._pp_size = mpu.get_pipeline_model_parallel_world_size() - self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - self._model_chunk_size = self._vpp_size or 1 - - # each one holds a list of model_chunks in this pp stage - self._pp_models = [None] * self.pp_size - - rank_list = list(range(self.pp_size)) - # make current rank the last one to initialize - rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank] - self._this_rank_models = None - - # store the parameter of each pp stage - self.memory_buffers = [None] * self.pp_size - for cur_pp_rank in rank_list: - print( - "create pp model", - f"torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB", - ) - # since the last initialized rank is the current pp rank, after init, the pp rank is still correct - mpu.set_pipeline_model_parallel_rank(cur_pp_rank) - if cur_pp_rank != self.pp_rank: - models = get_model(model_provider, wrap_with_ddp=False, use_distributed_optimizer=False) - models = nn.ModuleList(models) - assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}" - self.pp_models[cur_pp_rank] = models - else: - # for regular model, we wrapped it with DDP - models = get_model(model_provider, wrap_with_ddp=True, use_distributed_optimizer=use_distributed_optimizer) - assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}" - self._this_rank_models = nn.ModuleList(models) - self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP))) - - self._build_param_buffer(cur_pp_rank) - self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank) - - # TODO: after binding to the memory buffer, we can load the checkpoint here - if cur_pp_rank != self.pp_rank: - for model in self.pp_models[cur_pp_rank]: - model.eval() - self._offload_params_to_cpu(cur_pp_rank) - - def _build_param_buffer(self, pp_rank): - """Build the parameter buffer in each pp rank""" - if pp_rank == self._pp_rank: - from verl.utils.memory_buffer import MemoryBuffer - - # The code here is very hard-coded, based on the following assumptions: - # 1. `len(_this_rank_models) == 1` - # 2. `_this_rank_models[0]` is a instance of `DistributedDataParallel` and `use_distributed_optimizer=True` - # 3. Only bfloat16 data type is used in parameters - source = self._this_rank_models[0].buffers[0].param_data - self.memory_buffers[pp_rank] = {torch.bfloat16: MemoryBuffer(source.numel(), source.numel(), torch.bfloat16, source)} - else: - model = self.pp_models[pp_rank] - weight_buffer_meta = get_weight_buffer_meta_from_module(model) - self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta) - - def _build_param_references(self, pp_rank, maintain_weight=False): - if pp_rank == self._pp_rank: - return - model = self.pp_models[pp_rank] - build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight) - - def _load_params_to_cuda(self, pp_rank, to_empty=False): - assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda" - for buffer in self.memory_buffers[pp_rank].values(): - if not to_empty: - buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True) - else: - buffer.data = torch.empty_like(buffer.data, device="cuda") - # rebuild reference after loading to CUDA - self._build_param_references(pp_rank) - - def _offload_params_to_cpu(self, pp_rank, to_empty=False): - assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu" - for buffer in self.memory_buffers[pp_rank].values(): - if not to_empty: - # offload the whole memory buffer to CPU - buffer.data = buffer.data.to("cpu", non_blocking=True) - else: - buffer.data = torch.empty_like(buffer.data, device="cpu") - self._build_param_references(pp_rank) - - def load_params_to_cuda(self, to_empty=False): - """load all model params to cuda""" - for cur_pp_rank in range(self.pp_size): - if cur_pp_rank != self.pp_rank: - self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty) - - def allgather_params(self): - """allgather params of all pp ranks. Return a list of handles""" - for cur_pp_rank in range(self.pp_size): - global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank) - - # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models - - for _, param in sorted(self.pp_models[cur_pp_rank].named_parameters()): - dist.broadcast(tensor=param.data, src=global_src, group=self.pp_group, async_op=False) - - def forward(self, *inputs, **kwargs): - try: - prev_output = None - for cur_chunk_rank in range(self._model_chunk_size): - if self._vpp_size: - mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank) - - for cur_pp_rank in range(self.pp_size): - mpu.set_pipeline_model_parallel_rank(cur_pp_rank) - self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output) - ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs) - self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None) - prev_output = ret - finally: - if self._vpp_size: - mpu.set_virtual_pipeline_model_parallel_rank(0) - mpu.set_pipeline_model_parallel_rank(self.pp_rank) - return ret - - def __call__(self, *inputs, **kwargs): - return self.forward(*inputs, **kwargs) - - def eval(self): - for model in self.pp_models[self.pp_rank]: - model.eval() - - def train(self): - for model in self.pp_models[self.pp_rank]: - model.train() - - def offload_params_to_cpu(self, to_empty=False): - """offload params of models that are not of current pp rank to cpu""" - for cur_pp_rank in range(self.pp_size): - if cur_pp_rank != self.pp_rank: - self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty) - - def get_all_params(self): - """Get all the parameters of the models in all pp ranks - - Returns: - params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict - tensors of each model chunk - - """ - params = [] - for pp_rank in range(self.pp_size): - params.append([]) - for model_chunk_idx in range(len(self.pp_models[pp_rank])): - params[pp_rank].append({}) - pp_model = self.pp_models[pp_rank][model_chunk_idx] - pp_model = unwrap_model(pp_model, ((torchDDP, LocalDDP, Float16Module))) # not use Float16Module - for name, param in pp_model.named_parameters(): - # NOTE(gh) workaround: should not get lora params for inference - if "lora" in name: - continue - params[pp_rank][model_chunk_idx][name] = param - - return params - - def update_this_rank_models(self, new_models): - self._this_rank_models = new_models - self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP)) - - @property - def this_rank_models(self): - return self._this_rank_models - - @property - def pp_size(self): - return self._pp_size - - @property - def pp_rank(self): - return self._pp_rank - - @property - def pp_group(self): - return self._pp_group - - @property - def pp_models(self): - return self._pp_models - - """ Megatron Hybrid Engine: - During training, only the current pp stage holds the parameters @@ -256,42 +54,70 @@ def pp_models(self): """ -# Micro Data parallel group. Micro data parallel group is additional dp group that origins from splitting training tp -# into infer_tp and micro_tp. By default, we use order micro_dp - tp -# NOTICE: in new version of vLLM, We need to all-gather all tp rank's model weights -# For code reuse, we directly assign Megatron's TENSOR_MODEL_PARALLEL_GROUP to this -_MICRO_DATA_PARALLEL_GROUP = None - - class MegatronVLLMShardingManager(BaseShardingManager): - @check_cuda_is_available() + """A sharding manager that bridges Megatron-LM training with vLLM inference. + + This class handles the parameter sharding and communication between: + - Megatron-LM's tensor/expert parallel training setup + - vLLM's tensor parallel inference setup + + Key responsibilities: + - Manages parameter broadcasting between training and inference configurations + - Handles weight conversion between Megatron and HuggingFace formats + - Coordinates memory management between training and inference phases + - Maintains random state consistency across different parallel groups + + Args: + actor_module (nn.ModuleList): The Megatron-LM model being trained + inference_engine (LLM): The vLLM inference engine + model_config: Configuration for the actor's model + transformer_config: Transformer-specific configuration for the model + rollout_config: Configuration for rollout + layer_name_mapping: Mapping between Megatron and HF layer names + weight_converter (McoreToHFWeightConverterBase): Converts weights between formats + device_mesh: Device mesh for parallel operations + offload_param (bool): Whether to offload parameters when not in use + """ + + @check_device_is_available() def __init__( self, actor_module: nn.ModuleList, inference_engine: LLM, - model_config, + model_config: DictConfig, transformer_config, + rollout_config: DictConfig, layer_name_mapping, weight_converter: McoreToHFWeightConverterBase, - module: AllGatherPPModel = None, + device_mesh, + offload_param: bool = True, + bridge=None, ): - from megatron.core import parallel_state as mpu - self.actor_module = actor_module self.inference_engine = inference_engine + self.offload_param = offload_param + + # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model + self.model_runner = ( + self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner + if self.inference_engine + else None + ) + self.model_config = model_config self.transformer_config = transformer_config + self.rollout_config = rollout_config self.layer_name_mapping = layer_name_mapping self.weight_converter = weight_converter - self.module = module + self.bridge = bridge # initialize groups for vllm inference self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() - self.infer_tp_size = vllm_ps.get_tensor_model_parallel_world_size() - self.infer_tp_rank = vllm_ps.get_tensor_model_parallel_rank() - self.infer_tp_group = vllm_ps.get_tensor_model_parallel_group() - if vllm_version not in ("0.5.4", "0.6.3"): - self.infer_tp_group = self.infer_tp_group.device_group + + self.device_mesh = device_mesh + self.infer_tp_size = self.device_mesh["infer_tp"].size() + self.infer_tp_rank = self.device_mesh["infer_tp"].get_local_rank() + self.train_tp_size = mpu.get_tensor_model_parallel_world_size() self.train_tp_rank = mpu.get_tensor_model_parallel_rank() self.train_tp_group = mpu.get_tensor_model_parallel_group() @@ -304,57 +130,85 @@ def __init__( self.need_tp_reshard = self.train_tp_size != self.infer_tp_size self.train_tp_larger = self.train_tp_size > self.infer_tp_size + self.torch_random_states = get_torch_device().get_rng_state() + if self.device_mesh is not None: + gen_dp_rank = self.device_mesh["dp"].get_local_rank() + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) + else: + self.gen_random_states = None + @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) def __enter__(self): - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - per_tensor_param = per_tensor_generator(self.actor_module, self.model_config, self.weight_converter, self.transformer_config, self.layer_name_mapping, convert_qkv_gate_up_by_simple_split=False) - self.inference_engine.sync_model_weights(per_tensor_param, load_format="megatron") - else: - # > 0.7.2 - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - self.inference_engine.wake_up(tags=["weights"]) + self.timing = {} + with simple_timer("reshard", self.timing): + get_torch_device().empty_cache() + + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_megatron_model_to_gpu(self.actor_module) + + if self.rollout_config.free_cache_engine: + if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: + self.inference_engine.wake_up(tags=["weights"]) + else: + self.inference_engine.wake_up() + if self.bridge is not None: + per_tensor_param = self.bridge.export_weights(self.actor_module) else: - self.inference_engine.wake_up() - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + per_tensor_param = per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) + model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) loaded_params = model.load_weights(per_tensor_param) info = f"vLLM load weights, loaded_params: {len(loaded_params)}" logger.info(info) - # (vermouth1992) We move wake up kv cache after we release model weights. Need refactor to make API cleaner - # if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - # self.inference_engine.wake_up(tags=["kv_cache"]) + if self.offload_param: + offload_megatron_model_to_cpu(self.actor_module) + get_torch_device().empty_cache() + + if ( + self.rollout_config.free_cache_engine + and "tags" in inspect.signature(self.inference_engine.wake_up).parameters + ): + self.inference_engine.wake_up(tags=["kv_cache"]) + + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): - if vllm_version in ( - "0.5.4", - "0.6.3", - ): - self.inference_engine.offload_model_weights() - else: + if self.rollout_config.free_cache_engine: self.inference_engine.sleep(level=1) for model in self.actor_module: model.train() - torch.cuda.empty_cache() + get_torch_device().empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp if self.infer_tp_size == 1: return data - all_gather_data_proto(data, self.infer_tp_group) + + # TODO: Current impl doesn't consider FSDP with torch micro-dp + group = vllm_ps.get_tensor_model_parallel_group().device_group + + all_gather_data_proto(data=data, process_group=group) return data @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger)